-rw-r--r--  tensorflow/contrib/ffmpeg/BUILD                       4
-rw-r--r--  tensorflow/contrib/ffmpeg/decode_audio_op.cc        230
-rw-r--r--  tensorflow/contrib/ffmpeg/decode_audio_op_test.py    89
-rw-r--r--  tensorflow/contrib/ffmpeg/encode_audio_op.cc        149
-rw-r--r--  tensorflow/contrib/ffmpeg/encode_audio_op_test.py    59
-rw-r--r--  tensorflow/contrib/ffmpeg/ffmpeg_ops.py              49
6 files changed, 477 insertions, 103 deletions
diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD
index e495ab4880..15224fbbcb 100644
--- a/tensorflow/contrib/ffmpeg/BUILD
+++ b/tensorflow/contrib/ffmpeg/BUILD
@@ -87,6 +87,8 @@ tf_py_test(
srcs = ["decode_audio_op_test.py"],
additional_deps = [
":ffmpeg_ops_py",
+ "@six_archive//:six",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
],
@@ -102,6 +104,8 @@ tf_py_test(
srcs = ["encode_audio_op_test.py"],
additional_deps = [
":ffmpeg_ops_py",
+ "@six_archive//:six",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
],
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
index a6a945094b..4b1c8a337e 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
@@ -60,8 +60,175 @@ class FileDeleter {
const string filename_;
};
+/*
+ * Decoding implementation, shared across V1 and V2 ops. Creates a new
+ * output in the context.
+ */
+void Decode(OpKernelContext* context,
+ const tensorflow::StringPiece& file_contents,
+ const string& file_format, const int32 samples_per_second,
+ const int32 channel_count) {
+ // Write the input data to a temp file.
+ const string temp_filename = GetTempFilename(file_format);
+ OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents));
+ FileDeleter deleter(temp_filename);
+
+ // Run FFmpeg on the data and verify results.
+ std::vector<float> output_samples;
+ Status result =
+ ffmpeg::ReadAudioFile(temp_filename, file_format, samples_per_second,
+ channel_count, &output_samples);
+ if (result.code() == error::Code::NOT_FOUND) {
+ OP_REQUIRES(
+ context, result.ok(),
+ errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
+ "can be found at http://www.ffmpeg.org."));
+ } else if (result.code() == error::UNKNOWN) {
+ LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
+ << "'. Returning empty tensor.";
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({0, 0}), &output));
+ return;
+ } else {
+ OP_REQUIRES_OK(context, result);
+ }
+ OP_REQUIRES(context, !output_samples.empty(),
+ errors::Unknown("No output created by FFmpeg."));
+ OP_REQUIRES(
+ context, output_samples.size() % channel_count == 0,
+ errors::Unknown("FFmpeg created non-integer number of audio frames."));
+
+ // Copy the output data to the output Tensor.
+ Tensor* output = nullptr;
+ const int64 frame_count = output_samples.size() / channel_count;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({frame_count, channel_count}), &output));
+ auto matrix = output->tensor<float, 2>();
+ for (int32 frame = 0; frame < frame_count; ++frame) {
+ for (int32 channel = 0; channel < channel_count; ++channel) {
+ matrix(frame, channel) = output_samples[frame * channel_count + channel];
+ }
+ }
+}
+
} // namespace
+/*
+ * Supersedes `DecodeAudioOp`. Allows all parameters to be inputs
+ * instead of attributes, so that they can be given as tensors rather
+ * than constants only.
+ */
+class DecodeAudioOpV2 : public OpKernel {
+ public:
+ explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES(
+ context, context->num_inputs() == 4,
+ errors::InvalidArgument("DecodeAudio requires exactly four inputs."));
+
+ const Tensor& contents_tensor = context->input(0);
+ const Tensor& file_format_tensor = context->input(1);
+ const Tensor& samples_per_second_tensor = context->input(2);
+ const Tensor& channel_count_tensor = context->input(3);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()),
+ errors::InvalidArgument(
+ "contents must be a rank-0 tensor but got shape ",
+ contents_tensor.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(file_format_tensor.shape()),
+ errors::InvalidArgument(
+ "file_format must be a rank-0 tensor but got shape ",
+ file_format_tensor.shape().DebugString()));
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(samples_per_second_tensor.shape()),
+ errors::InvalidArgument(
+ "samples_per_second must be a rank-0 tensor but got shape ",
+ samples_per_second_tensor.shape().DebugString()));
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(channel_count_tensor.shape()),
+ errors::InvalidArgument(
+ "channel_count must be a rank-0 tensor but got shape ",
+ channel_count_tensor.shape().DebugString()));
+
+ const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
+ const string file_format =
+ str_util::Lowercase(file_format_tensor.scalar<string>()());
+ const int32 samples_per_second =
+ samples_per_second_tensor.scalar<int32>()();
+ const int32 channel_count = channel_count_tensor.scalar<int32>()();
+
+ const std::set<string> valid_file_formats(
+ kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
+ OP_REQUIRES(
+ context, valid_file_formats.count(file_format) == 1,
+ errors::InvalidArgument("file_format must be one of {",
+ str_util::Join(valid_file_formats, ", "),
+ "}, but was: \"", file_format, "\""));
+ OP_REQUIRES(context, samples_per_second > 0,
+ errors::InvalidArgument(
+ "samples_per_second must be positive, but got: ",
+ samples_per_second));
+ OP_REQUIRES(
+ context, channel_count > 0,
+ errors::InvalidArgument("channel_count must be positive, but got: ",
+ channel_count));
+
+ Decode(context, contents, file_format, samples_per_second, channel_count);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DecodeAudioV2").Device(DEVICE_CPU),
+ DecodeAudioOpV2);
+
+REGISTER_OP("DecodeAudioV2")
+ .Input("contents: string")
+ .Input("file_format: string")
+ .Input("samples_per_second: int32")
+ .Input("channel_count: int32")
+ .Output("sampled_audio: float")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ const Tensor* channels_tensor = c->input_tensor(3);
+ if (channels_tensor == nullptr) {
+ c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
+ return Status::OK();
+ }
+ const int32 channels = channels_tensor->scalar<int32>()();
+ if (channels <= 0) {
+ return errors::InvalidArgument(
+ "channel_count must be positive, but got: ", channels);
+ }
+ c->set_output(0, c->Matrix(c->UnknownDim(), channels));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Processes the contents of an audio file into a tensor using FFmpeg to decode
+the file.
+
+One column of the tensor is created for each channel in the audio file. Each
+channel contains audio samples starting at the beginning of the audio and
+having `1/samples_per_second` time between them. If `channel_count` differs
+from the number of channels in the file, channels will be merged or created.
+
+contents: The binary audio file contents, as a string or rank-0 string
+ tensor.
+file_format: A string or rank-0 string tensor describing the audio file
+ format. This must be one of: "mp3", "mp4", "ogg", "wav".
+samples_per_second: The number of samples per second that the audio
+ should have, as an `int` or rank-0 `int32` tensor. This value must
+ be positive.
+channel_count: The number of channels of audio to read, as an `int` or
+  rank-0 `int32` tensor. This value must be positive.
+sampled_audio: A rank-2 tensor containing all tracks of the audio.
+ Dimension 0 is time and dimension 1 is the channel. If ffmpeg fails
+ to decode the audio then an empty tensor will be returned.
+)doc");
+
+/*
+ * Deprecated in favor of DecodeAudioOpV2.
+ */
class DecodeAudioOp : public OpKernel {
public:
explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -69,15 +236,11 @@ class DecodeAudioOp : public OpKernel {
file_format_ = str_util::Lowercase(file_format_);
const std::set<string> valid_file_formats(
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
- OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
- errors::InvalidArgument(
- "file_format arg must be in {",
- str_util::Join(valid_file_formats, ", "), "}."));
-
- OP_REQUIRES_OK(
- context, context->GetAttr("samples_per_second", &samples_per_second_));
- OP_REQUIRES(context, samples_per_second_ > 0,
- errors::InvalidArgument("samples_per_second must be > 0."));
+ OP_REQUIRES(
+ context, valid_file_formats.count(file_format_) == 1,
+ errors::InvalidArgument("file_format must be one of {",
+ str_util::Join(valid_file_formats, ", "),
+ "}, but was: \"", file_format_, "\""));
OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
OP_REQUIRES(context, channel_count_ > 0,
@@ -95,51 +258,9 @@ class DecodeAudioOp : public OpKernel {
errors::InvalidArgument("contents must be scalar but got shape ",
contents.shape().DebugString()));
- // Write the input data to a temp file.
const tensorflow::StringPiece file_contents = contents.scalar<string>()();
- const string input_filename = GetTempFilename(file_format_);
- OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents));
- FileDeleter deleter(input_filename);
-
- // Run FFmpeg on the data and verify results.
- std::vector<float> output_samples;
- Status result =
- ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_,
- channel_count_, &output_samples);
- if (result.code() == error::Code::NOT_FOUND) {
- OP_REQUIRES(
- context, result.ok(),
- errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
- "can be found at http://www.ffmpeg.org."));
- } else if (result.code() == error::UNKNOWN) {
- LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
- << "'. Returning empty tensor.";
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, TensorShape({0, 0}), &output));
- return;
- } else {
- OP_REQUIRES_OK(context, result);
- }
- OP_REQUIRES(context, !output_samples.empty(),
- errors::Unknown("No output created by FFmpeg."));
- OP_REQUIRES(
- context, output_samples.size() % channel_count_ == 0,
- errors::Unknown("FFmpeg created non-integer number of audio frames."));
-
- // Copy the output data to the output Tensor.
- Tensor* output = nullptr;
- const int64 frame_count = output_samples.size() / channel_count_;
- OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({frame_count, channel_count_}), &output));
- auto matrix = output->tensor<float, 2>();
- for (int32 frame = 0; frame < frame_count; ++frame) {
- for (int32 channel = 0; channel < channel_count_; ++channel) {
- matrix(frame, channel) =
- output_samples[frame * channel_count_ + channel];
- }
- }
+ Decode(context, file_contents, file_format_, samples_per_second_,
+ channel_count_);
}
private:
@@ -178,8 +299,7 @@ contents: The binary audio file contents.
sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
is time and dimension 1 is the channel. If ffmpeg fails to decode the audio
then an empty tensor will be returned.
-file_format: A string describing the audio file format. This can be "wav" or
- "mp3".
+file_format: A string describing the audio file format. This can be "mp3",
+  "mp4", "ogg", or "wav".
samples_per_second: The number of samples per second that the audio should have.
channel_count: The number of channels of audio to read.
)doc");
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
index e4ed46b1e2..0d7c9cb99e 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
@@ -20,7 +20,11 @@ from __future__ import print_function
import os.path
+import six
+
from tensorflow.contrib import ffmpeg
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@@ -28,7 +32,8 @@ from tensorflow.python.platform import test
class DecodeAudioOpTest(test.TestCase):
def _loadFileAndTest(self, filename, file_format, duration_sec,
- samples_per_second, channel_count):
+ samples_per_second, channel_count,
+ samples_per_second_tensor=None, feed_dict=None):
"""Loads an audio file and validates the output tensor.
Args:
@@ -37,7 +42,16 @@ class DecodeAudioOpTest(test.TestCase):
duration_sec: The duration of the audio contained in the file in seconds.
samples_per_second: The desired sample rate in the output tensor.
channel_count: The desired channel count in the output tensor.
+ samples_per_second_tensor: The value to pass to the corresponding
+ parameter in the instantiated `decode_audio` op. If not
+ provided, will default to a constant value of
+ `samples_per_second`. Useful for providing a placeholder.
+ feed_dict: Used when evaluating the `decode_audio` op. If not
+ provided, will be empty. Useful when providing a placeholder for
+ `samples_per_second_tensor`.
"""
+ if samples_per_second_tensor is None:
+ samples_per_second_tensor = samples_per_second
with self.test_session():
path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
filename)
@@ -47,9 +61,9 @@ class DecodeAudioOpTest(test.TestCase):
audio_op = ffmpeg.decode_audio(
contents,
file_format=file_format,
- samples_per_second=samples_per_second,
+ samples_per_second=samples_per_second_tensor,
channel_count=channel_count)
- audio = audio_op.eval()
+ audio = audio_op.eval(feed_dict=feed_dict or {})
self.assertEqual(len(audio.shape), 2)
self.assertNear(
duration_sec * samples_per_second,
@@ -104,6 +118,75 @@ class DecodeAudioOpTest(test.TestCase):
audio = audio_op.eval()
self.assertEqual(audio.shape, (0, 0))
+ def testSampleRatePlaceholder(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1,
+ samples_per_second_tensor=placeholder,
+ feed_dict={placeholder: 20000})
+
+ def testSampleRateBadType(self):
+ placeholder = array_ops.placeholder(dtypes.float32)
+ with self.assertRaises(TypeError):
+ self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1,
+ samples_per_second_tensor=placeholder,
+ feed_dict={placeholder: 20000.0})
+
+ def testSampleRateBadValue_Zero(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ with six.assertRaisesRegex(self, Exception,
+ r'samples_per_second must be positive'):
+ self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1,
+ samples_per_second_tensor=placeholder,
+ feed_dict={placeholder: 0})
+
+ def testSampleRateBadValue_Negative(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ with six.assertRaisesRegex(self, Exception,
+ r'samples_per_second must be positive'):
+ self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1,
+ samples_per_second_tensor=placeholder,
+ feed_dict={placeholder: -2})
+
+ def testInvalidFileFormat(self):
+ with six.assertRaisesRegex(self, Exception,
+ r'file_format must be one of'):
+ self._loadFileAndTest('mono_16khz.mp3', 'docx', 0.57, 20000, 1)
+
+ def testStaticShapeInference_ConstantChannelCount(self):
+ with self.test_session():
+ audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~',
+ file_format='wav',
+ samples_per_second=44100,
+ channel_count=2)
+ self.assertEqual([None, 2], audio_op.shape.as_list())
+
+ def testStaticShapeInference_NonConstantChannelCount(self):
+ with self.test_session():
+ channel_count = array_ops.placeholder(dtypes.int32)
+ audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~',
+ file_format='wav',
+ samples_per_second=44100,
+ channel_count=channel_count)
+ self.assertEqual([None, None], audio_op.shape.as_list())
+
+ def testStaticShapeInference_ZeroChannelCountInvalid(self):
+ with self.test_session():
+ with six.assertRaisesRegex(self, Exception,
+ r'channel_count must be positive'):
+ ffmpeg.decode_audio(b'~~~ wave ~~~',
+ file_format='wav',
+ samples_per_second=44100,
+ channel_count=0)
+
+ def testStaticShapeInference_NegativeChannelCountInvalid(self):
+ with self.test_session():
+ with six.assertRaisesRegex(self, Exception,
+ r'channel_count must be positive'):
+ ffmpeg.decode_audio(b'~~~ wave ~~~',
+ file_format='wav',
+ samples_per_second=44100,
+ channel_count=-2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
index bd3d6ae699..c00cccd846 100644
--- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
@@ -22,7 +22,137 @@
namespace tensorflow {
namespace ffmpeg {
+namespace {
+/*
+ * Encoding implementation, shared across V1 and V2 ops. Creates a new
+ * output in the context.
+ */
+void Encode(OpKernelContext* context, const Tensor& contents,
+ const string& file_format, const int32 bits_per_second,
+ const int32 samples_per_second) {
+ std::vector<float> samples;
+ samples.reserve(contents.NumElements());
+ for (int32 i = 0; i < contents.NumElements(); ++i) {
+ samples.push_back(contents.flat<float>()(i));
+ }
+ const int32 channel_count = contents.dim_size(1);
+ string encoded_audio;
+ OP_REQUIRES_OK(
+ context, CreateAudioFile(file_format, bits_per_second, samples_per_second,
+ channel_count, samples, &encoded_audio));
+
+ // Copy the encoded audio file to the output tensor.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output));
+ output->scalar<string>()() = encoded_audio;
+}
+
+} // namespace
+
+/*
+ * Supersedes `EncodeAudioOp`. Allows all parameters to be inputs
+ * instead of attributes, so that the sample rate (and, probably less
+ * usefully, the output file format) can be given as tensors rather than
+ * constants only.
+ */
+class EncodeAudioOpV2 : public OpKernel {
+ public:
+ explicit EncodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES(
+ context, context->num_inputs() == 4,
+ errors::InvalidArgument("EncodeAudio requires exactly four inputs."));
+
+ const Tensor& contents = context->input(0);
+ const Tensor& file_format_tensor = context->input(1);
+ const Tensor& samples_per_second_tensor = context->input(2);
+ const Tensor& bits_per_second_tensor = context->input(3);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
+ errors::InvalidArgument(
+ "sampled_audio must be a rank-2 tensor but got shape ",
+ contents.shape().DebugString()));
+ OP_REQUIRES(
+ context, contents.NumElements() <= std::numeric_limits<int32>::max(),
+ errors::InvalidArgument(
+ "sampled_audio cannot have more than 2^31 entries. Shape = ",
+ contents.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(file_format_tensor.shape()),
+ errors::InvalidArgument(
+ "file_format must be a rank-0 tensor but got shape ",
+ file_format_tensor.shape().DebugString()));
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(samples_per_second_tensor.shape()),
+ errors::InvalidArgument(
+ "samples_per_second must be a rank-0 tensor but got shape ",
+ samples_per_second_tensor.shape().DebugString()));
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(bits_per_second_tensor.shape()),
+ errors::InvalidArgument(
+ "bits_per_second must be a rank-0 tensor but got shape ",
+ bits_per_second_tensor.shape().DebugString()));
+
+ const string file_format =
+ str_util::Lowercase(file_format_tensor.scalar<string>()());
+ const int32 samples_per_second =
+ samples_per_second_tensor.scalar<int32>()();
+ const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()();
+
+ OP_REQUIRES(context, file_format == "wav",
+ errors::InvalidArgument(
+ "file_format must be \"wav\", but got: ", file_format));
+ OP_REQUIRES(context, samples_per_second > 0,
+ errors::InvalidArgument(
+ "samples_per_second must be positive, but got: ",
+ samples_per_second));
+ OP_REQUIRES(
+ context, bits_per_second > 0,
+ errors::InvalidArgument("bits_per_second must be positive, but got: ",
+ bits_per_second));
+
+ Encode(context, contents, file_format, bits_per_second, samples_per_second);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("EncodeAudioV2").Device(DEVICE_CPU),
+ EncodeAudioOpV2);
+
+REGISTER_OP("EncodeAudioV2")
+ .Input("sampled_audio: float")
+ .Input("file_format: string")
+ .Input("samples_per_second: int32")
+ .Input("bits_per_second: int32")
+ .Output("contents: string")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Processes a `Tensor` containing sampled audio with the number of channels
+and length of the audio specified by the dimensions of the `Tensor`. The
+audio is converted into a string that, when saved to disk, will be equivalent
+to the audio in the specified audio format.
+
+The input audio has one column of the tensor for each channel in the audio
+file. Each channel contains audio samples starting at the beginning of the
+audio and having `1/samples_per_second` time between them. The output file
+will contain all of the audio channels contained in the tensor.
+
+sampled_audio: A rank-2 float tensor containing all tracks of the audio.
+ Dimension 0 is time and dimension 1 is the channel.
+file_format: A string or rank-0 string tensor describing the audio file
+ format. This value must be `"wav"`.
+samples_per_second: The number of samples per second that the audio should
+ have, as an int or rank-0 `int32` tensor. This value must be
+ positive.
+bits_per_second: The approximate bitrate of the encoded audio file, as
+ an int or rank-0 `int32` tensor. This is ignored by the "wav" file
+ format.
+contents: The binary audio file contents, as a rank-0 string tensor.
+)doc");
+
+/*
+ * Deprecated in favor of EncodeAudioOpV2.
+ */
class EncodeAudioOp : public OpKernel {
public:
explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -55,23 +185,8 @@ class EncodeAudioOp : public OpKernel {
"sampled_audio cannot have more than 2^31 entries. Shape = ",
contents.shape().DebugString()));
- // Create the encoded audio file.
- std::vector<float> samples;
- samples.reserve(contents.NumElements());
- for (int32 i = 0; i < contents.NumElements(); ++i) {
- samples.push_back(contents.flat<float>()(i));
- }
- const int32 channel_count = contents.dim_size(1);
- string encoded_audio;
- OP_REQUIRES_OK(context, CreateAudioFile(file_format_, bits_per_second_,
- samples_per_second_, channel_count,
- samples, &encoded_audio));
-
- // Copy the encoded audio file to the output tensor.
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, TensorShape(), &output));
- output->scalar<string>()() = encoded_audio;
+ Encode(context, contents, file_format_, bits_per_second_,
+ samples_per_second_);
}
private:
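
Note: EncodeAudioV2 follows the same pattern on the encode side, taking the file
format, sample rate, and bitrate as input tensors rather than attrs. A minimal
sketch under the same assumptions as before (TF 1.x session style; the tiny
waveform literal and 44100 Hz rate are illustrative, and the Python wrapper at
the end of this diff supplies bits_per_second itself):

    import tensorflow as tf
    from tensorflow.contrib import ffmpeg

    # A rank-2 float tensor: dimension 0 is time, dimension 1 is the channel.
    waveform = [[0.0], [0.25], [0.5], [0.25]]

    # The sample rate can again be fed at run time.
    rate = tf.placeholder(tf.int32)
    encode_op = ffmpeg.encode_audio(waveform,
                                    file_format='wav',
                                    samples_per_second=rate)

    with tf.Session() as sess:
      wav_bytes = sess.run(encode_op, feed_dict={rate: 44100})  # scalar string
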
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
index 18d992911d..870290dc10 100644
--- a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
@@ -20,13 +20,24 @@ from __future__ import print_function
import os.path
+import six
+
from tensorflow.contrib import ffmpeg
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class EncodeAudioOpTest(test.TestCase):
+ def setUp(self):
+ super(EncodeAudioOpTest, self).setUp()
+ path = os.path.join(resource_loader.get_data_files_path(),
+ 'testdata/mono_10khz.wav')
+ with open(path, 'rb') as f:
+ self._contents = f.read()
+
def _compareWavFiles(self, original, encoded):
"""Compares the important bits of two WAV files.
@@ -51,20 +62,54 @@ class EncodeAudioOpTest(test.TestCase):
def testRoundTrip(self):
"""Reads a wav file, writes it, and compares them."""
with self.test_session():
- path = os.path.join(resource_loader.get_data_files_path(),
- 'testdata/mono_10khz.wav')
- with open(path, 'rb') as f:
- original_contents = f.read()
-
audio_op = ffmpeg.decode_audio(
- original_contents,
+ self._contents,
file_format='wav',
samples_per_second=10000,
channel_count=1)
encode_op = ffmpeg.encode_audio(
audio_op, file_format='wav', samples_per_second=10000)
encoded_contents = encode_op.eval()
- self._compareWavFiles(original_contents, encoded_contents)
+ self._compareWavFiles(self._contents, encoded_contents)
+
+ def testRoundTripWithPlaceholderSampleRate(self):
+ with self.test_session():
+ placeholder = array_ops.placeholder(dtypes.int32)
+ audio_op = ffmpeg.decode_audio(
+ self._contents,
+ file_format='wav',
+ samples_per_second=placeholder,
+ channel_count=1)
+ encode_op = ffmpeg.encode_audio(
+ audio_op, file_format='wav', samples_per_second=placeholder)
+ encoded_contents = encode_op.eval(feed_dict={placeholder: 10000})
+ self._compareWavFiles(self._contents, encoded_contents)
+
+ def testFloatingPointSampleRateInvalid(self):
+ with self.test_session():
+ with self.assertRaises(TypeError):
+ ffmpeg.encode_audio(
+ [[0.0], [1.0]],
+ file_format='wav',
+ samples_per_second=12345.678)
+
+ def testZeroSampleRateInvalid(self):
+ with self.test_session() as sess:
+ encode_op = ffmpeg.encode_audio(
+ [[0.0], [1.0]],
+ file_format='wav',
+ samples_per_second=0)
+ with six.assertRaisesRegex(self, Exception, 'must be positive'):
+ sess.run(encode_op)
+
+ def testNegativeSampleRateInvalid(self):
+ with self.test_session() as sess:
+ encode_op = ffmpeg.encode_audio(
+ [[0.0], [1.0]],
+ file_format='wav',
+ samples_per_second=-2)
+ with six.assertRaisesRegex(self, Exception, 'must be positive'):
+ sess.run(encode_op)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index 5f608cdb82..18b0b8b812 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -38,23 +38,26 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
Args:
contents: The binary contents of the audio file to decode. This is a
scalar.
- file_format: A string specifying which format the contents will conform
- to. This can be mp3, mp4, ogg, or wav.
- samples_per_second: The number of samples per second that is assumed.
- In some cases, resampling will occur to generate the correct sample
- rate.
+ file_format: A string or scalar string tensor specifying which
+ format the contents will conform to. This can be mp3, mp4, ogg,
+ or wav.
+ samples_per_second: The number of samples per second that is
+ assumed, as an `int` or scalar `int32` tensor. In some cases,
+ resampling will occur to generate the correct sample rate.
channel_count: The number of channels that should be created from the
- audio contents. If the contents have more than this number, then
- some channels will be merged or dropped. If contents has fewer than
- this, then additional channels will be created from the existing ones.
+ audio contents, as an `int` or scalar `int32` tensor. If the
+ `contents` have more than this number, then some channels will
+ be merged or dropped. If `contents` has fewer than this, then
+ additional channels will be created from the existing ones.
Returns:
- A rank 2 tensor that has time along dimension 0 and channels along
- dimension 1. Dimension 0 will be `samples_per_second * length` wide, and
- dimension 1 will be `channel_count` wide. If ffmpeg fails to decode the
- audio then an empty tensor will be returned.
+ A rank-2 tensor that has time along dimension 0 and channels along
+ dimension 1. Dimension 0 will be `samples_per_second *
+ length_in_seconds` wide, and dimension 1 will be `channel_count`
+ wide. If ffmpeg fails to decode the audio then an empty tensor will
+ be returned.
"""
- return gen_decode_audio_op_py.decode_audio(
+ return gen_decode_audio_op_py.decode_audio_v2(
contents, file_format=file_format, samples_per_second=samples_per_second,
channel_count=channel_count)
@@ -66,19 +69,23 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
"""Creates an op that encodes an audio file using sampled audio from a tensor.
Args:
- audio: A rank 2 tensor that has time along dimension 0 and channels along
- dimension 1. Dimension 0 is `samples_per_second * length` long in
- seconds.
- file_format: The type of file to encode. "wav" is the only supported format.
- samples_per_second: The number of samples in the audio tensor per second of
- audio.
+ audio: A rank-2 `Tensor` that has time along dimension 0 and
+ channels along dimension 1. Dimension 0 is `samples_per_second *
+ length_in_seconds` long.
+ file_format: The type of file to encode, as a string or rank-0
+ string tensor. "wav" is the only supported format.
+ samples_per_second: The number of samples in the audio tensor per
+ second of audio, as an `int` or rank-0 `int32` tensor.
Returns:
A scalar tensor that contains the encoded audio in the specified file
format.
"""
- return gen_encode_audio_op_py.encode_audio(
- audio, file_format=file_format, samples_per_second=samples_per_second)
+ return gen_encode_audio_op_py.encode_audio_v2(
+ audio,
+ file_format=file_format,
+ samples_per_second=samples_per_second,
+ bits_per_second=192000) # not used by WAV
ops.NotDifferentiable('EncodeAudio')
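
Note: a closing sketch of the validation added by the V2 ops, based on the
kernels and tests above (the byte literal and rates are illustrative; TF 1.x
session style assumed). A non-positive channel_count is rejected while the
graph is built, via the new shape function, while a non-positive sample rate
fed at run time is rejected by the kernel's input checks.

    import tensorflow as tf
    from tensorflow.contrib import ffmpeg

    # Rejected at graph-construction time by DecodeAudioV2's shape function.
    try:
      ffmpeg.decode_audio(b'~~~ wave ~~~', file_format='wav',
                          samples_per_second=44100, channel_count=0)
    except Exception as e:
      print(e)  # channel_count must be positive, but got: 0

    # Rejected at run time by the kernel's input validation.
    rate = tf.placeholder(tf.int32)
    audio = ffmpeg.decode_audio(b'~~~ wave ~~~', file_format='wav',
                                samples_per_second=rate, channel_count=1)
    with tf.Session() as sess:
      try:
        sess.run(audio, feed_dict={rate: 0})
      except tf.errors.InvalidArgumentError as e:
        print(e.message)  # samples_per_second must be positive, but got: 0
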