aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-05 08:06:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-05 09:11:43 -0700
commit931e848c28e97e8cae410af242f8e09d75663ee4 (patch)
treeaa8c70767b9cf78877436bce8424cdf851c2bb7d
parent78f134b6880d0f43ffd1d5049d504f2333e30681 (diff)
Adding an encode_audio op to the tensorflow/contrib/ffmpeg directory.
Change: 121584222
-rw-r--r--tensorflow/contrib/ffmpeg/BUILD49
-rw-r--r--tensorflow/contrib/ffmpeg/__init__.py1
-rw-r--r--tensorflow/contrib/ffmpeg/encode_audio_op.cc106
-rw-r--r--tensorflow/contrib/ffmpeg/encode_audio_op_test.py50
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py36
5 files changed, 238 insertions, 4 deletions
diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD
index 295b297455..3e8a1340a4 100644
--- a/tensorflow/contrib/ffmpeg/BUILD
+++ b/tensorflow/contrib/ffmpeg/BUILD
@@ -31,10 +31,25 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "encode_audio_op_cc",
+ srcs = ["encode_audio_op.cc"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:private"],
+ deps = [
+ "//third_party/eigen3",
+ "//tensorflow/contrib/ffmpeg/kernels:ffmpeg_lib",
+ "//tensorflow/core:framework_headers_lib",
+ ],
+ alwayslink = 1,
+)
+
tf_custom_op_library(
- name = "decode_audio_op.so",
+ name = "ffmpeg.so",
deps = [
":decode_audio_op_cc",
+ ":encode_audio_op_cc",
],
)
@@ -47,6 +62,15 @@ tf_gen_op_wrapper_py(
],
)
+tf_gen_op_wrapper_py(
+ name = "encode_audio_op_py",
+ require_shape_functions = True,
+ visibility = ["//visibility:private"],
+ deps = [
+ ":encode_audio_op_cc",
+ ],
+)
+
tf_py_test(
name = "decode_audio_op_test",
srcs = ["decode_audio_op_test.py"],
@@ -56,7 +80,25 @@ tf_py_test(
"//tensorflow/python:platform",
],
data = [
- ":decode_audio_op.so",
+ ":ffmpeg.so",
+ ":test_data",
+ ],
+ tags = [
+ "local",
+ "manual",
+ ],
+)
+
+tf_py_test(
+ name = "encode_audio_op_test",
+ srcs = ["encode_audio_op_test.py"],
+ additional_deps = [
+ ":ffmpeg_ops_py",
+ "//third_party/py/tensorflow",
+ "//tensorflow/python:platform",
+ ],
+ data = [
+ ":ffmpeg.so",
":test_data",
],
tags = [
@@ -72,11 +114,12 @@ py_library(
"ffmpeg_ops.py",
],
data = [
- ":decode_audio_op.so",
+ ":ffmpeg.so",
],
srcs_version = "PY2AND3",
deps = [
":decode_audio_op_py",
+ ":encode_audio_op_py",
],
)
diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py
index 31591df453..50c51d615b 100644
--- a/tensorflow/contrib/ffmpeg/__init__.py
+++ b/tensorflow/contrib/ffmpeg/__init__.py
@@ -19,3 +19,4 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio
+from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
new file mode 100644
index 0000000000..ed24e56532
--- /dev/null
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
@@ -0,0 +1,106 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include <limits>
+
+#include "tensorflow/contrib/ffmpeg/kernels/ffmpeg_lib.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace ffmpeg {
+
+class EncodeAudioOp : public OpKernel {
+ public:
+ explicit EncodeAudioOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
+ file_format_ = str_util::Lowercase(file_format_);
+ OP_REQUIRES(context, file_format_ == "wav",
+ errors::InvalidArgument("file_format arg must be \"wav\"."));
+
+ OP_REQUIRES_OK(
+ context, context->GetAttr("samples_per_second", &samples_per_second_));
+ OP_REQUIRES(context, samples_per_second_ > 0,
+ errors::InvalidArgument("samples_per_second must be > 0."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Get and verify the input data.
+ OP_REQUIRES(context, context->num_inputs() == 1,
+ errors::InvalidArgument(
+ "EncodeAudio requires exactly one input."));
+ const Tensor& contents = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
+ errors::InvalidArgument(
+ "sampled_audio must be a rank 2 tensor but got shape ",
+ contents.shape().DebugString()));
+ OP_REQUIRES(
+ context, contents.NumElements() <= std::numeric_limits<int32>::max(),
+ errors::InvalidArgument(
+ "sampled_audio cannot have more than 2^31 entries. Shape = ",
+ contents.shape().DebugString()));
+
+ // Create the encoded audio file.
+ std::vector<float> samples;
+ samples.reserve(contents.NumElements());
+ for (int32 i = 0; i < contents.NumElements(); ++i) {
+ samples.push_back(contents.flat<float>()(i));
+ }
+ const int32 channel_count = contents.dim_size(1);
+ string encoded_audio;
+ OP_REQUIRES_OK(context,
+ CreateAudioFile(file_format_, samples_per_second_,
+ channel_count, samples, &encoded_audio));
+
+ // Copy the encoded audio file to the output tensor.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape(), &output));
+ output->scalar<string>()() = encoded_audio;
+ }
+
+ private:
+ string file_format_;
+ int32 samples_per_second_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("EncodeAudio").Device(DEVICE_CPU), EncodeAudioOp);
+
+REGISTER_OP("EncodeAudio")
+ .Input("sampled_audio: float")
+ .Output("contents: string")
+ .Attr("file_format: string")
+ .Attr("samples_per_second: int")
+ .Doc(R"doc(
+Processes a `Tensor` containing sampled audio with the number of channels
+and length of the audio specified by the dimensions of the `Tensor`. The
+audio is converted into a string that, when saved to disk, will be equivalent
+to the audio in the specified audio format.
+
+The input audio has one row of the tensor for each channel in the audio file.
+Each channel contains audio samples starting at the beginning of the audio and
+having `1/samples_per_second` time between them. The output file will contain
+all of the audio channels contained in the tensor.
+
+sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
+ is time and dimension 1 is the channel.
+contents: The binary audio file contents.
+file_format: A string describing the audio file format. This must be "wav".
+samples_per_second: The number of samples per second that the audio should have.
+)doc");
+
+} // namespace ffmpeg
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
new file mode 100644
index 0000000000..74f05917a6
--- /dev/null
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
@@ -0,0 +1,50 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for third_party.tensorflow.contrib.ffmpeg.encode_audio_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+from tensorflow.contrib import ffmpeg
+from tensorflow.python.platform import resource_loader
+
+
+class EncodeAudioOpTest(tf.test.TestCase):
+
+ def testRoundTrip(self):
+ """Fabricates some audio, creates a wav file, reverses it, and compares."""
+ with self.test_session():
+ path = os.path.join(
+ resource_loader.get_data_files_path(), 'testdata/mono_10khz.wav')
+ with open(path, 'r') as f:
+ original_contents = f.read()
+
+ audio_op = ffmpeg.decode_audio(
+ original_contents, file_format='wav', samples_per_second=10000,
+ channel_count=1)
+ encode_op = ffmpeg.encode_audio(
+ audio_op, file_format='wav', samples_per_second=10000)
+ encoded_contents = encode_op.eval()
+ self.assertEqual(original_contents, encoded_contents)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index d443b24ba8..cdb8006e6b 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py
+from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -75,6 +76,39 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
ops.NoGradient('DecodeAudio')
+@ops.RegisterShape('EncodeAudio')
+def _encode_audio_shape(unused_op):
+ """Computes the shape of an EncodeAudio operation.
+
+ Returns:
+ A list of output shapes. There's exactly one output, the formatted audio
+ file. This is a rank 0 tensor.
+ """
+ return [tensor_shape.TensorShape([])]
+
+
+def encode_audio(audio, file_format=None, samples_per_second=None):
+ """Creates an op that encodes an audio file using sampled audio from a tensor.
+
+ Args:
+ audio: A rank 2 tensor that has time along dimension 0 and channels along
+ dimension 1. Dimension 0 is `samples_per_second * length` long in
+ seconds.
+ file_format: The type of file to encode. "wav" is the only supported format.
+ samples_per_second: The number of samples in the audio tensor per second of
+ audio.
+
+ Returns:
+ A scalar tensor that contains the encoded audio in the specified file
+ format.
+ """
+ return gen_encode_audio_op_py.encode_audio(
+ audio, file_format=file_format, samples_per_second=samples_per_second)
+
+
+ops.NoGradient('EncodeAudio')
+
+
def _load_library(name, op_list=None):
"""Loads a .so file containing the specified operators.
@@ -97,4 +131,4 @@ def _load_library(name, op_list=None):
(expected_op, name))
-_load_library('decode_audio_op.so', ['DecodeAudio'])
+_load_library('ffmpeg.so', ['DecodeAudio', 'EncodeAudio'])