diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h')
-rw-r--r-- | tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h new file mode 100644 index 0000000000..18ff73ac39 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H + +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +// The base class for Cluster. A cluster is group of nodes all related to each +// other because their name match a given "pattern", which shows they all belong +// to a composite op supported in TFLite. The nodes in a cluster will be +// collapsed into a single composite op node plus a series of constant nodes +// holding the input parameters to that node. The nodes in a cluster are assumed +// to be using the same device. By changing the "pattern" we can have different +// subclasses of the base Cluster class. +class Cluster { + public: + virtual ~Cluster() {} + + virtual void CreateNodes() = 0; + + // Save the following info from the original GraphDef this cluster is from: + // 1- a pointer to the GraphDef + // 2- All the nodes in GraphDef which belong to this cluster. + void SetGraphDefInfo(const tensorflow::GraphDef* graph_def); + + const string& GetName() const { return name_; } + + const std::vector<std::unique_ptr<tensorflow::NodeDef>>& GetNewNodes() const { + return new_nodes_; + } + + const std::vector<const tensorflow::NodeDef*>& GetNodes() { return nodes_; } + + void SetName(const string& name) { name_ = name; } + + void SetDevice(const string& device) { device_ = device; } + + // Find the input(s) and output(s) of this Cluster. + bool FindClusterInputsAndOutputs(); + + protected: + string name_; + string device_; + std::vector<string> inputs_; + std::vector<string> outputs_; + + // Used to hold the pointers to nodes which are in this cluster. These nodes + // are pointing to the nodes in graph_def_. + std::vector<const tensorflow::NodeDef*> nodes_; + + // Used to cache the newly generated nodes: like the nodes created by + // collapsing Const nodes, or the nodes which is used to show the composite + // op. + std::vector<std::unique_ptr<tensorflow::NodeDef>> new_nodes_; + + const tensorflow::GraphDef* graph_def_; /*Not owned*/ +}; + +// A factory interface for cluster class. +// It defines a virtual function interface which is responsible for creating +// a cluster. Each cluster factory is responsible to pack a cluster of nodes +// into a cluster using a name-based pattern matching approach. +class ClusterFactoryInterface { + public: + virtual ~ClusterFactoryInterface() {} + + // Creates a cluster of nodes using a name-based pattern matching approach. It + // uses a node as a seed and if its name matches a certain pattern, then it + // builds the cluster around that node. + virtual std::unique_ptr<Cluster> CreateCluster( + const tensorflow::NodeDef& node, + const tensorflow::GraphDef& graph_def) const = 0; +}; + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H |