aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-06-20 18:36:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 18:39:22 -0700
commit96dfcc2fdc9f3a7419d3d5c5a64489e757de624e (patch)
tree8c684731bde1643158037bf1d4ed17e58c95096a
parente8b18a6f0c02d364ff47ba5fa3dc61458d273674 (diff)
Support filter format for FusedConv2DBiasActivation.
PiperOrigin-RevId: 201454730
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py20
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc93
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h10
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc119
4 files changed, 151 insertions, 91 deletions
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index a955e21b72..4d62ac65ff 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -35,13 +33,6 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
-def NoMemoryOptimizationConfig():
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.memory_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
-
def GetShrunkInceptionShapes(shrink=10):
"""Iterator for smaller versions of convolution shapes in 2015 Inception.
@@ -202,8 +193,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
# This is to guarantee that there is always negative values after
# bias add so that we can test whether relu works correctly.
x3 = bias
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()):
+ with self.test_session(use_gpu=True):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
fused_t2 = t2
@@ -251,9 +241,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
def _SetupVal(data_format, use_gpu):
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(
- use_gpu=use_gpu, config=NoMemoryOptimizationConfig()):
+ with self.test_session(use_gpu=use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes)
t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
@@ -877,9 +865,7 @@ class FusedConvInt8Tests(test.TestCase):
conv_input_scale, conv_input, kernel, padding_type, strides,
side_input_scale, side_input, biases)
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(
- use_gpu=True, config=NoMemoryOptimizationConfig()) as sess:
+ with self.test_session(use_gpu=True) as sess:
actual_y, expected_y = sess.run([actual, expected])
tf_logging.info("actual_y = ", actual_y)
tf_logging.info("expected_y = ", expected_y)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b994d26397..d34eecd009 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -78,6 +78,14 @@ string GetDataFormat(const OpInfo& op_features) {
return data_format;
}
+string GetFilterFormat(const OpInfo& op_features) {
+ string filter_format = "HWIO"; // Default format.
+ if (op_features.attr().find("filter_format") != op_features.attr().end()) {
+ filter_format = op_features.attr().at("filter_format").s();
+ }
+ return filter_format;
+}
+
Padding GetPadding(const OpInfo& op_features) {
if (op_features.attr().find("padding") != op_features.attr().end() &&
op_features.attr().at("padding").s() == "VALID") {
@@ -513,29 +521,44 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
y_index = 3;
channel_index = 1;
} else {
+ // Use NHWC.
x_index = 1;
y_index = 2;
channel_index = 3;
}
+ const string& filter_format = GetFilterFormat(op_features);
+ int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
+ if (filter_format == "HWIO") {
+ filter_x_index = 0;
+ filter_y_index = 1;
+ in_channel_index = 2;
+ out_channel_index = 3;
+ } else {
+ // Use OIHW
+ filter_x_index = 2;
+ filter_y_index = 3;
+ in_channel_index = 1;
+ out_channel_index = 0;
+ }
int64 batch = image_shape.dim(0).size();
int64 ix = image_shape.dim(x_index).size();
int64 iy = image_shape.dim(y_index).size();
int64 iz = image_shape.dim(channel_index).size();
- int64 kx = filter_shape.dim(0).size();
- int64 ky = filter_shape.dim(1).size();
+ int64 kx = filter_shape.dim(filter_x_index).size();
+ int64 ky = filter_shape.dim(filter_y_index).size();
std::vector<int64> strides = GetStrides(op_features);
const auto padding = GetPadding(op_features);
int64 sx = strides[x_index];
int64 sy = strides[y_index];
int64 ox = GetOutputSize(ix, kx, sx, padding);
int64 oy = GetOutputSize(iy, ky, sy, padding);
- int64 oz = filter_shape.dim(3).size();
+ int64 oz = filter_shape.dim(out_channel_index).size();
// Only check equality when both sizes are known (in other words, when
// neither is set to a minimum dimension size of 1).
- if (iz != 1 && filter_shape.dim(2).size() != 1) {
- CHECK_EQ(iz, filter_shape.dim(2).size());
+ if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
+ CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
} else {
- iz = std::max<int64>(iz, filter_shape.dim(2).size());
+ iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
}
OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
@@ -1054,6 +1077,24 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
//
// For more information, see
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+
+ // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
+ // filter formats (OIHW_VECT_I).
+ string data_format = GetDataFormat(op_context.op_info);
+ if (data_format != "NCHW" && data_format != "NHWC") {
+ LOG(WARNING) << "unsupported data format: " << data_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+ string filter_format = GetFilterFormat(op_context.op_info);
+ if (filter_format != "HWIO" && filter_format != "OIHW") {
+ LOG(WARNING) << "unsupported filter format: " << filter_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+
auto& conv_input = op_context.op_info.inputs(0);
auto& filter = op_context.op_info.inputs(1);
auto& bias = op_context.op_info.inputs(2);
@@ -1069,28 +1110,12 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct the shape of our output tensor from our convolution dimensions
// and format, as it may not be available yet.
- //
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
- bool unknown_conv_format = false;
OpInfo::TensorProperties output;
- switch (GetConvolutionFormat(op_context)) {
- case NCHW:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
- break;
- case NHWC:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- break;
- default:
- // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
- LOG(WARNING) << "unsupported data format: "
- << GetDataFormat(op_context.op_info)
- << " Defaulting to NHWC.";
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- unknown_conv_format = true;
- break;
+ if (data_format == "NCHW") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+ } else if (data_format == "NHWC") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
}
// Add the operations the fused op always computes.
@@ -1115,7 +1140,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct component operations and run the cost computation.
auto costs = PredictFusedOp(op_context_with_output, component_ops);
- costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+ costs.inaccurate |= found_unknown_shapes;
return costs;
}
@@ -1568,20 +1593,6 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
}
/* static */
-OpLevelCostEstimator::ConvolutionFormat
-OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
- auto data_format = GetDataFormat(op_context.op_info);
- if (data_format == "NCHW") {
- return NCHW;
- } else if (data_format == "NHWC") {
- return NHWC;
- } else if (data_format == "NCHW_VECT_C") {
- return NCHW_VECT_C;
- }
-
- return UNKNOWN_CONVOLUTION_FORMAT;
-}
-
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index d384f57279..a277dfdf65 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -84,13 +84,6 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
- enum ConvolutionFormat {
- UNKNOWN_CONVOLUTION_FORMAT,
- NHWC,
- NCHW,
- NCHW_VECT_C,
- NCHW_VECT_W,
- };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -198,9 +191,6 @@ class OpLevelCostEstimator {
static OpInfo::TensorProperties DescribeTensor(
DataType type, const std::vector<int64>& dims);
- // Returns the Conv2D format for this operation.
- static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
-
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index b2c021b73a..77352f6652 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -155,19 +155,38 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
// Note that this assumes the NHWC data format.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
int iz2, int kx, int ky, int ox,
- int oy, int oz,
- bool has_side_input) {
+ int oy, int oz, bool has_side_input,
+ const string& data_format,
+ const string& filter_format) {
OpContext op_context;
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("FusedConv2DBiasActivation");
- DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
- DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ auto* attr_data_format = op_context.op_info.mutable_attr();
+ SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
+ auto* attr_filter_format = op_context.op_info.mutable_attr();
+ SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+ } else {
+ // Use the NCHW format.
+ DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
+ }
+ if (filter_format == "HWIO") {
+ DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ } else {
+ // Use the OIHW format.
+ DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
+ }
DescribeTensor1D(oz, op_context.op_info.add_inputs());
// Add the side_input, if any.
auto side_input = op_context.op_info.add_inputs();
if (has_side_input) {
- DescribeTensor4D(batch, ox, oy, oz, side_input);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ox, oy, oz, side_input);
+ } else {
+ DescribeTensor4D(batch, oz, ox, oy, side_input);
+ }
}
// Add the scaling tensors.
@@ -549,25 +568,79 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
SetComputeMemoryOverlap(false); // Set it back to default.
}
-TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest,
+ FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true));
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
+ "NCHW", "HWIO"));
+ EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "HWIO"));
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
-TEST_F(OpLevelCostEstimatorTest,
- FusedConv2DBiasActivationNoSideInputExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false));
- EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
- EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
- EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "HWIO"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once NCHW_VECT_C is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW_VECT_C", "OIHW"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once OIHW_VECT_I is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW_VECT_I"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
@@ -655,8 +728,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
TensorProto tensor_proto;
TensorShapeProto tensor_shape_proto;
- // Dimension larger than max value; should fail while converting to Tensor
- // class.
+ // Dimension larger than max value; should fail while converting to
+ // Tensor class.
tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
EXPECT_FALSE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
@@ -676,8 +749,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
// Check GetTensorShapeProtoFromTensorProto() resturns correct values.
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -685,8 +758,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -694,8 +767,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -703,8 +776,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);