diff options
author | Russell Power <power@google.com> | 2018-09-20 13:48:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 13:53:34 -0700 |
commit | d388770922ad1afa95e55597a33836fe74035c75 (patch) | |
tree | 3c15e17c357645e3e872ea46daaea2c91b00e9c1 /tensorflow/core/common_runtime | |
parent | 1f1e5ac6154583d5f87c846d1d7c9c59a77d6e0c (diff) |
Implement TF graph capture.
PiperOrigin-RevId: 213875284
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/session_ref.cc | 170 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/session_ref.h | 86 |
2 files changed, 0 insertions, 256 deletions
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc deleted file mode 100644 index b931ef4229..0000000000 --- a/tensorflow/core/common_runtime/session_ref.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/session_ref.h" - -#include <utility> - -namespace tensorflow { - -namespace { - -// Scope helper to track active calls and manage session lifetime. -struct RunCounter { - std::shared_ptr<Session> session; - uint64* value; - mutex* m; - condition_variable* cv; - - explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, - condition_variable* cv) - : session(std::move(s)), value(v), m(m), cv(cv) { - mutex_lock l(*m); - ++*value; - } - - ~RunCounter() { - mutex_lock l(*m); - if (--*value == 0) { - cv->notify_all(); - } - } -}; - -} // namespace - -Status SessionRef::CheckNotClosed() { - mutex_lock l(run_lock_); - if (session_ == nullptr) return errors::Cancelled("Session has been closed."); - return ::tensorflow::Status::OK(); -} - -Status SessionRef::Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(run_options, inputs, output_tensor_names, - target_node_names, outputs, run_metadata); -} - -Status SessionRef::Create(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(graph); -} - -Status SessionRef::Create(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(run_options, graph); -} - -Status SessionRef::Extend(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(run_options, graph); -} - -Status SessionRef::Extend(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(graph); -} - -Status SessionRef::Close(const RunOptions& run_options) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(run_options); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Close() { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(inputs, output_tensor_names, target_node_names, - outputs); -} - -Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ListDevices(response); -} - -Status SessionRef::PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRunSetup(input_names, output_names, target_nodes, handle); -} - -Status SessionRef::PRun(const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRun(handle, inputs, output_names, outputs); -} - -Status SessionRef::MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->MakeCallable(callable_options, out_handle); -} - -Status SessionRef::RunCallable(CallableHandle handle, - const std::vector<Tensor>& feed_tensors, - std::vector<Tensor>* fetch_tensors, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->RunCallable(handle, feed_tensors, fetch_tensors, - run_metadata); -} - -Status SessionRef::ReleaseCallable(CallableHandle handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ReleaseCallable(handle); -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h deleted file mode 100644 index 9459e7edbe..0000000000 --- a/tensorflow/core/common_runtime/session_ref.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ - -#include <memory> - -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { - -// A `SessionRef` manages the lifetime of a wrapped `Session` pointer. -// -// SessionRef blocks the return of Close() until all pending operations have -// been completed or cancelled and underlying session has been freed. Any -// subsequent operations on the SessionRef object will return errors::Cancelled. -class SessionRef : public Session { - public: - SessionRef(Session* session) : session_(session) {} - virtual ~SessionRef() {} - - Status Create(const GraphDef& graph) override; - Status Extend(const GraphDef& graph) override; - Status Create(const RunOptions& run_options, const GraphDef& graph) override; - Status Extend(const RunOptions& run_options, const GraphDef& graph) override; - Status Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs) override; - - Status ListDevices(std::vector<DeviceAttributes>* response) override; - - Status Close() override; - Status Close(const RunOptions& run_options) override; - - Status Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, RunMetadata* run_metadata) override; - - Status PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) override; - - Status PRun(const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs) override; - - Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override; - - Status RunCallable(CallableHandle handle, - const std::vector<Tensor>& feed_tensors, - std::vector<Tensor>* fetch_tensors, - RunMetadata* run_metadata) override; - - Status ReleaseCallable(CallableHandle handle) override; - - private: - mutex run_lock_; - condition_variable run_finished_; - uint64 run_count_ GUARDED_BY(run_lock_) = {0}; - std::shared_ptr<Session> session_; - - Status CheckNotClosed(); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ |