Data loading is slow due to grabbing frames from huggingface dataset #93
Maybe useful snippets. In `XarmDataset`:

```python
def __getitem__(self, idx):
    start = time.perf_counter()
    item = self.hf_dataset[idx]
    if self.delta_timestamps is not None:
        start_delta = time.perf_counter()
        item, select_cols_time = load_previous_and_future_frames(
            item,
            self.hf_dataset,
            self.episode_data_index,
            self.delta_timestamps,
            tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
        )
        self._delta_time += time.perf_counter() - start_delta
        self._select_cols_time += select_cols_time
    if self.transform is not None:
        item = self.transform(item)
    self._time += time.perf_counter() - start
    return item
```

You also have to make `load_previous_and_future_frames` return `select_cols_time`, wrapping its column selection like this:

```python
start = time.perf_counter()
item[key] = hf_dataset.select_columns(key)[data_ids][key]
select_cols_time += time.perf_counter() - start
```
Most of that time is spent in the `dataset.select_columns` call.
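One direction that could help (a sketch only — `ColumnCache` is a hypothetical helper, not part of lerobot): materialize the hot columns once up front, so each `__getitem__` indexes a plain Python list instead of issuing a fresh `select_columns` query against Arrow. A plain dict stands in for the HF dataset here:

```python
class ColumnCache:
    """Pull each column out of the backing dataset once; per-item reads
    are then cheap list indexing instead of repeated Arrow queries."""

    def __init__(self, hf_dataset, keys):
        # hf_dataset[key] returns the whole column; do this a single time
        self._cols = {key: list(hf_dataset[key]) for key in keys}

    def rows(self, key, data_ids):
        col = self._cols[key]
        return [col[i] for i in data_ids]

# A plain dict stands in for the HF dataset in this sketch.
fake_dataset = {"action": [0.0, 0.1, 0.2, 0.3], "next.reward": [0, 0, 1, 1]}
cache = ColumnCache(fake_dataset, ["action", "next.reward"])
print(cache.rows("action", [1, 2]))  # [0.1, 0.2]
```

The trade-off is memory: this only works for low-dimensional columns (states, actions, rewards), not for image frames.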
FYI, some things I tried: slicing instead of advanced indexing, but it didn't seem to help:

```python
if len(torch.unique(data_ids)) == len(data_ids):
    # note: assumes the unique ids are also sorted and contiguous,
    # otherwise the slice returns extra rows
    item[key] = hf_dataset.select_columns(key)[data_ids.min().item():data_ids.max().item() + 1][key]
else:
    item[key] = hf_dataset.select_columns(key)[data_ids][key]
```
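For context (a toy check with plain Python lists, not the HF API): min–max slicing reproduces advanced indexing only when the ids are sorted and contiguous, which is why a uniqueness check alone is not quite sufficient:

```python
col = ["f0", "f1", "f2", "f3", "f4"]

def fancy(col, ids):
    # advanced indexing: one row per id, in id order
    return [col[i] for i in ids]

def sliced(col, ids):
    # min-max slicing: everything between the extremes, inclusive
    return col[min(ids):max(ids) + 1]

contiguous = [1, 2, 3]
print(fancy(col, contiguous) == sliced(col, contiguous))  # True

gappy = [1, 3]  # unique, but not contiguous
print(fancy(col, gappy) == sliced(col, gappy))  # False: the slice also grabs index 2
```

For `delta_timestamps` the ids are consecutive frame indices, so the contiguity assumption usually holds in practice.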
In a more recent benchmark I found that, even with the video datasets, more time is taken up accessing the hf_dataset than decoding videos:

```python
from time import perf_counter
from tqdm import trange
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.video_utils import load_from_videos

fps = 15
horizon = 5
delta_timestamps = {
    "observation.image": [i / fps for i in range(horizon + 1)],
    "observation.state": [i / fps for i in range(horizon + 1)],
    "action": [i / fps for i in range(horizon)],
    "next.reward": [i / fps for i in range(horizon)],
}
dataset = LeRobotDataset("lerobot/xarm_lift_medium", delta_timestamps=delta_timestamps)

get_item_time = 0
delta_timestamps_time = 0
load_video_time = 0
transform_time = 0

def getitem(dataset: LeRobotDataset, idx: int) -> dict:
    global get_item_time, delta_timestamps_time, load_video_time, transform_time
    start = perf_counter()
    item = dataset.hf_dataset[idx]
    print("Get item:", (get_item_time := get_item_time + perf_counter() - start))
    start = perf_counter()
    if dataset.delta_timestamps is not None:
        item = load_previous_and_future_frames(
            item,
            dataset.hf_dataset,
            dataset.episode_data_index,
            dataset.delta_timestamps,
            dataset.tolerance_s,
        )
    print("Delta timestamps:", (delta_timestamps_time := delta_timestamps_time + perf_counter() - start))
    start = perf_counter()
    if dataset.video:
        item = load_from_videos(
            item,
            dataset.video_frame_keys,
            dataset.videos_dir,
            dataset.tolerance_s,
        )
    print("Load video:", (load_video_time := load_video_time + perf_counter() - start))
    start = perf_counter()
    if dataset.transform is not None:
        item = dataset.transform(item)
    print("Transform:", (transform_time := transform_time + perf_counter() - start))
    print()
    return item

for idx in trange(len(dataset)):
    getitem(dataset, idx)
```

The last iteration prints the cumulative time spent in each stage.
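The global-plus-walrus pattern above can be packaged as a small reusable accumulator (a sketch, stdlib only — `StageTimer` is a hypothetical helper, not a lerobot class):

```python
from contextlib import contextmanager
from time import perf_counter

class StageTimer:
    """Accumulate wall-clock time per named stage across many __getitem__ calls."""

    def __init__(self):
        self.totals = {}

    @contextmanager
    def stage(self, name):
        start = perf_counter()
        try:
            yield
        finally:
            self.totals[name] = self.totals.get(name, 0.0) + perf_counter() - start

# usage: wrap each stage of __getitem__ in a named block
timer = StageTimer()
for _ in range(3):
    with timer.stage("get_item"):
        sum(range(10_000))  # stand-in for dataset.hf_dataset[idx]
print(sorted(timer.totals))  # ['get_item']
```

This keeps the instrumentation out of the hot path's control flow and avoids the module-level globals.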
I ran an experiment where I benchmarked times for running a dataloader on `xarm_lift_medium_replay` with batch size 256 and 0 workers. 8 batches take ~15 s. Here's the breakdown (seconds):

- `dataset.__getitem__`: 14.707566491064426
- `load_previous_and_future_frames`: 13.887129129978348
- `hf_dataset.select_columns(key)[data_ids][key]`: 9.562710228981814
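To attribute the remaining time without hand-inserted timers, one option (a generic sketch, not lerobot-specific — `fetch_row` is a stand-in) is to profile the dataloader loop with the stdlib `cProfile` and sort by cumulative time:

```python
import cProfile
import io
import pstats

def fetch_row(i):
    # stand-in for dataset.__getitem__(i)
    return sum(range(i % 100))

def run_epoch(n):
    for i in range(n):
        fetch_row(i)

prof = cProfile.Profile()
prof.enable()
run_epoch(10_000)
prof.disable()

out = io.StringIO()
pstats.Stats(prof, stream=out).sort_stats("cumulative").print_stats(5)
print(out.getvalue())  # top 5 entries by cumulative time
```

Run against the real `getitem` loop, this would show whether the time inside `select_columns` is Arrow deserialization or Python-side overhead.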