From 23eff89e08c21abb3e5dde2fb5ebecabdcc64958 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 30 Sep 2022 13:58:08 +0100 Subject: [PATCH 1/4] fix bug in output format for pyav --- torchvision/io/video.py | 132 ++++++++++++++++++++-------------------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index ceb20fe52c0..ecf8ad5d0a0 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -273,72 +273,72 @@ def read_video( raise RuntimeError(f"File not found: {filename}") if get_video_backend() != "pyav": - return _video_opt._read_video(filename, start_pts, end_pts, pts_unit) - - _check_av_available() - - if end_pts is None: - end_pts = float("inf") - - if end_pts < start_pts: - raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") - - info = {} - video_frames = [] - audio_frames = [] - audio_timebase = _video_opt.default_timebase - - try: - with av.open(filename, metadata_errors="ignore") as container: - if container.streams.audio: - audio_timebase = container.streams.audio[0].time_base - if container.streams.video: - video_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.video[0], - {"video": 0}, - ) - video_fps = container.streams.video[0].average_rate - # guard against potentially corrupted files - if video_fps is not None: - info["video_fps"] = float(video_fps) - - if container.streams.audio: - audio_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.audio[0], - {"audio": 0}, - ) - info["audio_fps"] = container.streams.audio[0].rate - - except av.AVError: - # TODO raise a warning? - pass - - vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] - aframes_list = [frame.to_ndarray() for frame in audio_frames] - - if vframes_list: - vframes = torch.as_tensor(np.stack(vframes_list)) - else: - vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) - - if aframes_list: - aframes = np.concatenate(aframes_list, 1) - aframes = torch.as_tensor(aframes) - if pts_unit == "sec": - start_pts = int(math.floor(start_pts * (1 / audio_timebase))) - if end_pts != float("inf"): - end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) - aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) - else: - aframes = torch.empty((1, 0), dtype=torch.float32) + vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit) + else: + _check_av_available() + + if end_pts is None: + end_pts = float("inf") + + if end_pts < start_pts: + raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + + info = {} + video_frames = [] + audio_frames = [] + audio_timebase = _video_opt.default_timebase + + try: + with av.open(filename, metadata_errors="ignore") as container: + if container.streams.audio: + audio_timebase = container.streams.audio[0].time_base + if container.streams.video: + video_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + ) + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + + if container.streams.audio: + audio_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.audio[0], + {"audio": 0}, + ) + info["audio_fps"] = container.streams.audio[0].rate + + except av.AVError: + # TODO raise a warning? + pass + + vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] + aframes_list = [frame.to_ndarray() for frame in audio_frames] + + if vframes_list: + vframes = torch.as_tensor(np.stack(vframes_list)) + else: + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + + if aframes_list: + aframes = np.concatenate(aframes_list, 1) + aframes = torch.as_tensor(aframes) + if pts_unit == "sec": + start_pts = int(math.floor(start_pts * (1 / audio_timebase))) + if end_pts != float("inf"): + end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) + aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) + else: + aframes = torch.empty((1, 0), dtype=torch.float32) if output_format == "TCHW": # [T,H,W,C] --> [T,C,H,W] From 14cbbab239165be05096fd6cbb88cb0448502436 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Mon, 3 Oct 2022 15:38:08 +0100 Subject: [PATCH 2/4] add read from memory with constructor overload --- setup.py | 4 ++-- torchvision/csrc/io/video/video.cpp | 22 +++++++++++++--------- torchvision/csrc/io/video/video.h | 4 ++++ 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 25bef6b50de..130832abd2b 100644 --- a/setup.py +++ b/setup.py @@ -351,8 +351,8 @@ def get_extensions(): # FIXME: causes crash. See the following GitHub issues for more details. # FIXME: https://github.com/pytorch/pytorch/issues/65000 # FIXME: https://github.com/pytorch/vision/issues/3367 - if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): - has_ffmpeg = False + # if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): + # has_ffmpeg = False if has_ffmpeg: try: # This is to check if ffmpeg is installed properly. diff --git a/torchvision/csrc/io/video/video.cpp b/torchvision/csrc/io/video/video.cpp index 38b35014595..5157bca28fd 100644 --- a/torchvision/csrc/io/video/video.cpp +++ b/torchvision/csrc/io/video/video.cpp @@ -156,7 +156,7 @@ void Video::_getDecoderParams( } // _get decoder params -Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { +void Video::_init(std::string stream, int64_t numThreads) { C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video"); // set number of threads global numThreads_ = numThreads; @@ -173,13 +173,6 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { numThreads_ // global number of Threads for decoding ); - std::string logMessage, logType; - - // TODO: add read from memory option - params.uri = videoPath; - logType = "file"; - logMessage = videoPath; - // locals std::vector audioFPS, videoFPS; std::vector audioDuration, videoDuration, ccDuration, subsDuration; @@ -232,7 +225,18 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { << "Stream index set to " << std::get<1>(current_stream) << ". If you encounter trouble, consider switching it to automatic stream discovery. \n"; } -} // video +} // Video::Init + + +Video::Video(torch::Tensor videoData, std::string stream, int64_t numThreads) { + callback = MemoryBuffer::getCallback(videoData.data_ptr(), videoData.size(0)); + Video::_init(stream, numThreads); +} + +Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { + params.uri = videoPath; + Video::_init(stream, numThreads); +} bool Video::setCurrentStream(std::string stream = "video") { if ((!stream.empty()) && (_parseStream(stream) != current_stream)) { diff --git a/torchvision/csrc/io/video/video.h b/torchvision/csrc/io/video/video.h index 7cd926b793c..c9f03a84592 100644 --- a/torchvision/csrc/io/video/video.h +++ b/torchvision/csrc/io/video/video.h @@ -20,6 +20,8 @@ struct Video : torch::CustomClassHolder { public: Video(std::string videoPath, std::string stream, int64_t numThreads); + Video(torch::Tensor videoData, std::string stream, int64_t numThreads); + std::tuple getCurrentStream() const; c10::Dict>> getStreamMetadata() const; @@ -34,6 +36,8 @@ struct Video : torch::CustomClassHolder { // time in comination with any_frame settings double seekTS = -1; + void _init(std::string stream, int64_t numThreads); + void _getDecoderParams( double videoStartS, int64_t getPtsOnly, From 6ba592f964d887e173ab9c876d5aad67acde2f47 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Mon, 3 Oct 2022 15:42:06 +0100 Subject: [PATCH 3/4] Revert "add read from memory with constructor overload" This reverts commit 14cbbab239165be05096fd6cbb88cb0448502436. --- setup.py | 4 ++-- torchvision/csrc/io/video/video.cpp | 22 +++++++++------------- torchvision/csrc/io/video/video.h | 4 ---- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/setup.py b/setup.py index 130832abd2b..25bef6b50de 100644 --- a/setup.py +++ b/setup.py @@ -351,8 +351,8 @@ def get_extensions(): # FIXME: causes crash. See the following GitHub issues for more details. # FIXME: https://github.com/pytorch/pytorch/issues/65000 # FIXME: https://github.com/pytorch/vision/issues/3367 - # if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): - # has_ffmpeg = False + if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): + has_ffmpeg = False if has_ffmpeg: try: # This is to check if ffmpeg is installed properly. diff --git a/torchvision/csrc/io/video/video.cpp b/torchvision/csrc/io/video/video.cpp index 5157bca28fd..38b35014595 100644 --- a/torchvision/csrc/io/video/video.cpp +++ b/torchvision/csrc/io/video/video.cpp @@ -156,7 +156,7 @@ void Video::_getDecoderParams( } // _get decoder params -void Video::_init(std::string stream, int64_t numThreads) { +Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video"); // set number of threads global numThreads_ = numThreads; @@ -173,6 +173,13 @@ void Video::_init(std::string stream, int64_t numThreads) { numThreads_ // global number of Threads for decoding ); + std::string logMessage, logType; + + // TODO: add read from memory option + params.uri = videoPath; + logType = "file"; + logMessage = videoPath; + // locals std::vector audioFPS, videoFPS; std::vector audioDuration, videoDuration, ccDuration, subsDuration; @@ -225,18 +232,7 @@ void Video::_init(std::string stream, int64_t numThreads) { << "Stream index set to " << std::get<1>(current_stream) << ". If you encounter trouble, consider switching it to automatic stream discovery. \n"; } -} // Video::Init - - -Video::Video(torch::Tensor videoData, std::string stream, int64_t numThreads) { - callback = MemoryBuffer::getCallback(videoData.data_ptr(), videoData.size(0)); - Video::_init(stream, numThreads); -} - -Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { - params.uri = videoPath; - Video::_init(stream, numThreads); -} +} // video bool Video::setCurrentStream(std::string stream = "video") { if ((!stream.empty()) && (_parseStream(stream) != current_stream)) { diff --git a/torchvision/csrc/io/video/video.h b/torchvision/csrc/io/video/video.h index c9f03a84592..7cd926b793c 100644 --- a/torchvision/csrc/io/video/video.h +++ b/torchvision/csrc/io/video/video.h @@ -20,8 +20,6 @@ struct Video : torch::CustomClassHolder { public: Video(std::string videoPath, std::string stream, int64_t numThreads); - Video(torch::Tensor videoData, std::string stream, int64_t numThreads); - std::tuple getCurrentStream() const; c10::Dict>> getStreamMetadata() const; @@ -36,8 +34,6 @@ struct Video : torch::CustomClassHolder { // time in comination with any_frame settings double seekTS = -1; - void _init(std::string stream, int64_t numThreads); - void _getDecoderParams( double videoStartS, int64_t getPtsOnly, From e940ec74651bbda3c220a8127b48c4a8100db10d Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Mon, 3 Oct 2022 15:44:30 +0100 Subject: [PATCH 4/4] run ufmt --- torchvision/io/video.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index ecf8ad5d0a0..002fde9988c 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -274,14 +274,16 @@ def read_video( if get_video_backend() != "pyav": vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit) - else: + else: _check_av_available() if end_pts is None: end_pts = float("inf") if end_pts < start_pts: - raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + raise ValueError( + f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}" + ) info = {} video_frames = []