aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/captured_function.h
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2017-12-14 16:05:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 16:09:19 -0800
commita5b2a0c9a3335d10c4dd3dfdff96149f74a4d120 (patch)
tree442f65bd185bfe40ee348003e305e6c4a445ea13 /tensorflow/core/kernels/captured_function.h
parent481b5f4410b34b65570f9dce62b34e9199769a38 (diff)
Moving tf.data kernels to their own package.
PiperOrigin-RevId: 179112798
Diffstat (limited to 'tensorflow/core/kernels/captured_function.h')
-rw-r--r--tensorflow/core/kernels/captured_function.h115
1 files changed, 4 insertions, 111 deletions
diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h
index c10472dde0..cdf191f4c7 100644
--- a/tensorflow/core/kernels/captured_function.h
+++ b/tensorflow/core/kernels/captured_function.h
@@ -12,116 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_KERNELS_CAPTURED_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_KERNELS_CAPTURED_FUNCTION_H_
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
-#include <memory>
-#include <vector>
+#include "tensorflow/core/kernels/data/captured_function.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-
-class Device;
-class OpKernelContext;
-class ResourceMgr;
-
-// A `CapturedFunction` encapsulates a TensorFlow function and all of
-// the runtime support required to execute it.
-//
-// The `Dataset`-related classes use `CapturedFunction` to execute
-// TensorFlow functions outside a the normal `OpKernel::Compute()`
-// context.
-//
-// NOTE(mrry): Here we are taking a conservative approach to dealing with
-// ownership of the various framework and runtime objects that are needed
-// to execute functions. We copy the function library *definition* (i.e.
-// a set of FunctionDefs) out of this kernel's context's function library
-// *runtime*, then we use that together with a specially-created
-// ThreadPoolDevice to build a new FunctionLibraryRuntime for the Dataset.
-//
-// We need to do this (or refactor the ownership of framework components
-// in each of the session implementations) to make it possible to close
-// down a ParallelMapDataset::Iterator when its session is closed.
-//
-// TODO(mrry): Clean this up. Investigate whether it would be possible to
-// reuse the session's FunctionLibraryRuntime(s) or Device(s).
-class CapturedFunction {
- public:
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(OpKernelContext* ctx, const NameAttrList& func,
- int graph_def_version,
- std::vector<Tensor> captured_inputs,
- std::unique_ptr<CapturedFunction>* out_function);
-
- // Synchronously runs the captured function on the given `args`, and stores
- // the results in `*rets`. This method takes ownership of the tensors in
- // `args`, in order to be able to deallocate them as early as possible.
- // Use `RunWithBorrowedArgs()` if the caller needs to retain ownership of
- // the `args`.
- Status Run(FunctionLibraryRuntime::Options f_opts, std::vector<Tensor>&& args,
- std::vector<Tensor>* rets);
-
- // Synchronously runs the captured function on the given `args`, and stores
- // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
- // possible.
- Status RunWithBorrowedArgs(FunctionLibraryRuntime::Options f_opts,
- const std::vector<Tensor>& args,
- std::vector<Tensor>* rets);
-
- // Asynchronously runs the captured function on the given `args`, stores
- // the results in `*rets`, and calls the given `done` callback when the
- // function returns. This method takes ownership of the tensors in `args`,
- // in order to be able to deallocate them as early as possible.
- void RunAsync(FunctionLibraryRuntime::Options f_opts,
- std::vector<Tensor>&& args, std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
-
- // Returns a borrowed pointer to the `ResourceManager` used when this
- // function is run.
- ResourceMgr* resource_manager() const { return device_->resource_manager(); }
-
- // Returns that additional captured inputs that will be passed to the function
- // when `Run*()` is called.
- const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
-
- // Returns a step ID for use when running a `CapturedFunction`.
- static int64 generate_step_id() {
- // Choose a step ID that is guaranteed not to clash with any
- // Session-generated step ID. DirectSession only generates
- // non-negative step IDs (contiguous, starting from 0), and
- // MasterSession generates 56-bit random step IDs whose MSB is
- // always 0, so a negative random step ID should suffice.
- return -std::abs(static_cast<int64>(random::New64()));
- }
-
- private:
- CapturedFunction(Device* device, std::unique_ptr<DeviceMgr> device_mgr,
- std::unique_ptr<FunctionLibraryDefinition> flib_def,
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- FunctionLibraryRuntime* lib,
- FunctionLibraryRuntime::Handle f_handle,
- std::vector<Tensor> captured_inputs,
- DataTypeSlice ret_types);
-
- Device* const device_; // owned by device_mgr_.
- const std::unique_ptr<DeviceMgr> device_mgr_;
- const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionLibraryRuntime* const lib_; // owned by pflr_.
- const FunctionLibraryRuntime::Handle f_handle_;
- const std::vector<Tensor> captured_inputs_;
- DataTypeSlice ret_types_; // owned by pflr_.
-
- TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
-};
-
-} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_KERNELS_CAPTURED_FUNCTION_H_
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_