From cf1a5d335421d2a2c159ae2d44d3e54db7de004c Mon Sep 17 00:00:00 2001
From: Joao Gomes
Date: Tue, 4 Oct 2022 17:48:11 +0100
Subject: [PATCH] fix bug in output format for pyav (#6672)

* fix bug in output format for pyav

* add read from memory with constructor overload

* Revert "add read from memory with constructor overload"

This reverts commit 14cbbab239165be05096fd6cbb88cb0448502436.

* run ufmt
---
 torchvision/io/video.py | 132 ++++++++++++++++++++--------------------
 1 file changed, 67 insertions(+), 65 deletions(-)

diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index ceb20fe52c0..002fde9988c 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -273,72 +273,74 @@ def read_video(
         raise RuntimeError(f"File not found: {filename}")
 
     if get_video_backend() != "pyav":
-        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
-
-    _check_av_available()
-
-    if end_pts is None:
-        end_pts = float("inf")
-
-    if end_pts < start_pts:
-        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
-
-    info = {}
-    video_frames = []
-    audio_frames = []
-    audio_timebase = _video_opt.default_timebase
-
-    try:
-        with av.open(filename, metadata_errors="ignore") as container:
-            if container.streams.audio:
-                audio_timebase = container.streams.audio[0].time_base
-            if container.streams.video:
-                video_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.video[0],
-                    {"video": 0},
-                )
-                video_fps = container.streams.video[0].average_rate
-                # guard against potentially corrupted files
-                if video_fps is not None:
-                    info["video_fps"] = float(video_fps)
-
-            if container.streams.audio:
-                audio_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.audio[0],
-                    {"audio": 0},
-                )
-                info["audio_fps"] = container.streams.audio[0].rate
-
-    except av.AVError:
-        # TODO raise a warning?
-        pass
-
-    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    aframes_list = [frame.to_ndarray() for frame in audio_frames]
-
-    if vframes_list:
-        vframes = torch.as_tensor(np.stack(vframes_list))
-    else:
-        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-
-    if aframes_list:
-        aframes = np.concatenate(aframes_list, 1)
-        aframes = torch.as_tensor(aframes)
-        if pts_unit == "sec":
-            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
-            if end_pts != float("inf"):
-                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
-        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
     else:
-        aframes = torch.empty((1, 0), dtype=torch.float32)
+        _check_av_available()
+
+        if end_pts is None:
+            end_pts = float("inf")
+
+        if end_pts < start_pts:
+            raise ValueError(
+                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+            )
+
+        info = {}
+        video_frames = []
+        audio_frames = []
+        audio_timebase = _video_opt.default_timebase
+
+        try:
+            with av.open(filename, metadata_errors="ignore") as container:
+                if container.streams.audio:
+                    audio_timebase = container.streams.audio[0].time_base
+                if container.streams.video:
+                    video_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.video[0],
+                        {"video": 0},
+                    )
+                    video_fps = container.streams.video[0].average_rate
+                    # guard against potentially corrupted files
+                    if video_fps is not None:
+                        info["video_fps"] = float(video_fps)
+
+                if container.streams.audio:
+                    audio_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.audio[0],
+                        {"audio": 0},
+                    )
+                    info["audio_fps"] = container.streams.audio[0].rate
+
+        except av.AVError:
+            # TODO raise a warning?
+            pass
+
+        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+        aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+        if vframes_list:
+            vframes = torch.as_tensor(np.stack(vframes_list))
+        else:
+            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+        if aframes_list:
+            aframes = np.concatenate(aframes_list, 1)
+            aframes = torch.as_tensor(aframes)
+            if pts_unit == "sec":
+                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+                if end_pts != float("inf"):
+                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        else:
+            aframes = torch.empty((1, 0), dtype=torch.float32)
 
     if output_format == "TCHW":
         # [T,H,W,C] --> [T,C,H,W]