diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-10-09 11:38:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 11:48:46 -0700 |
commit | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch) | |
tree | d2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | |
parent | 931353c5f79c2d419afb3a5ecac59184c5558351 (diff) |
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 4d213b3f9c..a89db320ea 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -27,7 +27,8 @@ namespace toco { class GraphTransformation { public: - virtual bool Run(Model* model, std::size_t op_index) = 0; + virtual ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) = 0; virtual const char* Name() const = 0; virtual ~GraphTransformation() {} // Returns the list of messages that this graph transformation @@ -104,11 +105,12 @@ class GraphTransformationsSet { void RunGraphTransformations(Model* model, const string& message, const GraphTransformationsSet& transformations); -#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ - class GTName : public GraphTransformation { \ - public: \ - bool Run(Model* model, std::size_t op_index) override; \ - const char* Name() const override { return #GTName; } \ +#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ + class GTName : public GraphTransformation { \ + public: \ + ::tensorflow::Status Run(Model* model, std::size_t op_index, \ + bool* modified) override; \ + const char* Name() const override { return #GTName; } \ }; // List of all graph transformations @@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes) class PropagateDefaultMinMax : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "PropagateDefaultMinMax"; } bool has_any_ranges_defined() const { return !type_ranges_.empty(); } @@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation { class RemoveTrivialReshape : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "RemoveTrivialReshape"; } bool treat_expand_dims_as_trivial() const { return treat_expand_dims_as_trivial_; @@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation { class ResolveConstantFakeQuant : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "ResolveConstantFakeQuant"; } // True if the num_bits should adjust the final data type. @@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation { class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "EnsureUint8WeightsSafeForFastInt8Kernels"; } @@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { class IdentifyDilatedConv : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "IdentifyDilatedConv"; } bool identify_depthwise_conv() const { return identify_depthwise_conv_; } void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; } |