Skip to content

Commit

Permalink
fix bug in output format for pyav (#6672) (#6703)
Browse files Browse the repository at this point in the history
* fix bug in output format for pyav

* add read from memory with constructor overload

* Revert "add read from memory with constructor overload"

This reverts commit 14cbbab.

* run ufmt
  • Loading branch information
jdsgomes committed Oct 5, 2022
1 parent dc6d86d commit 8762598
Showing 1 changed file with 67 additions and 65 deletions.
132 changes: 67 additions & 65 deletions torchvision/io/video.py
Expand Up @@ -273,72 +273,74 @@ def read_video(
raise RuntimeError(f"File not found: {filename}")

if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

_check_av_available()

if end_pts is None:
end_pts = float("inf")

if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")

info = {}
video_frames = []
audio_frames = []
audio_timebase = _video_opt.default_timebase

try:
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

except av.AVError:
# TODO raise a warning?
pass

vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]

if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
_check_av_available()

if end_pts is None:
end_pts = float("inf")

if end_pts < start_pts:
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)

info = {}
video_frames = []
audio_frames = []
audio_timebase = _video_opt.default_timebase

try:
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)

if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate

except av.AVError:
# TODO raise a warning?
pass

vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]

if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)

if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
Expand Down

0 comments on commit 8762598

Please sign in to comment.