aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h29
1 files changed, 18 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 4d213b3f9c..a89db320ea 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -27,7 +27,8 @@ namespace toco {
class GraphTransformation {
public:
- virtual bool Run(Model* model, std::size_t op_index) = 0;
+ virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) = 0;
virtual const char* Name() const = 0;
virtual ~GraphTransformation() {}
// Returns the list of messages that this graph transformation
@@ -104,11 +105,12 @@ class GraphTransformationsSet {
void RunGraphTransformations(Model* model, const string& message,
const GraphTransformationsSet& transformations);
-#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
- class GTName : public GraphTransformation { \
- public: \
- bool Run(Model* model, std::size_t op_index) override; \
- const char* Name() const override { return #GTName; } \
+#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
+ class GTName : public GraphTransformation { \
+ public: \
+ ::tensorflow::Status Run(Model* model, std::size_t op_index, \
+ bool* modified) override; \
+ const char* Name() const override { return #GTName; } \
};
// List of all graph transformations
@@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
class PropagateDefaultMinMax : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "PropagateDefaultMinMax"; }
bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
@@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation {
class RemoveTrivialReshape : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "RemoveTrivialReshape"; }
bool treat_expand_dims_as_trivial() const {
return treat_expand_dims_as_trivial_;
@@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation {
class ResolveConstantFakeQuant : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "ResolveConstantFakeQuant"; }
// True if the num_bits should adjust the final data type.
@@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation {
class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override {
return "EnsureUint8WeightsSafeForFastInt8Kernels";
}
@@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
class IdentifyDilatedConv : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "IdentifyDilatedConv"; }
bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }