diff options
author | 2017-07-18 15:20:03 -0700 | |
---|---|---|
committer | 2017-07-18 15:23:57 -0700 | |
commit | d5f4d9bbac520ad9eae6614fe678e9d1568435a4 (patch) | |
tree | 4d0e5115006e12127650b4f97d9b34942938abf3 /tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h | |
parent | d8672f1839df67f765baaa34a4c806ee1d433842 (diff) |
Add a way to fuse a graph by remote graph executor so that users don't need to be aware of supported op types, node names, subgraph stracture etc.
PiperOrigin-RevId: 162411763
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h')
-rw-r--r-- | tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h index a0df50162b..3fa052108e 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils { TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils); }; +class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor { + public: + TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types, + const string& executor_name); + + int GetVersion() final; + bool Init(const RemoteFusedGraphExecuteInfo&) final; + bool Finalize() final; + bool SetupGraph() final; + bool ExecuteGraph() final; + bool TeardownGraph() final; + bool FillInputNode(const string&, const Tensor&) final; + bool ReadOutputNode(const string&, TensorAllocatorFunc) final; + Status FuseRemoteGraph(const GraphDef& original_graph_def, + const std::vector<string>& inputs, + const std::vector<string>& outputs, + GraphDef* fused_graph_def) final; + bool IsEnabled() const final; + + private: + const std::unordered_set<string> fused_op_types_; + const string executor_name_; +}; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ |