
Data loading is slow due to grabbing frames from huggingface dataset #93

Open
alexander-soare opened this issue Apr 23, 2024 · 3 comments
Labels
⚡️ Performance Performance-related

Comments

@alexander-soare (Collaborator) commented Apr 23, 2024

I ran an experiment benchmarking a dataloader on xarm_lift_medium_replay with batch size 256 and 0 workers.

8 batches take ~15 s. Here's the breakdown (cumulative seconds):

  • dataset.__getitem__: 14.707566491064426
  • load_previous_and_future_frames: 13.887129129978348
  • hf_dataset.select_columns(key)[data_ids][key]: 9.562710228981814
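Restated as fractions (my arithmetic on the numbers above, not part of the original report):

```python
# Re-express the breakdown above as shares of dataset.__getitem__ time
# (the timings are cumulative seconds over the 8 batches).
getitem_s = 14.707566491064426
load_frames_s = 13.887129129978348
select_cols_s = 9.562710228981814

print(f"load_previous_and_future_frames: {load_frames_s / getitem_s:.0%} of __getitem__")
print(f"select_columns lookup:           {select_cols_s / getitem_s:.0%} of __getitem__")
```

So roughly two thirds of the item-fetch time is the select_columns row gather alone.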
@alexander-soare (Collaborator, Author) commented Apr 23, 2024

Maybe useful snippets:

In XarmDataset

def __getitem__(self, idx):
    start = time.perf_counter()
    item = self.hf_dataset[idx]

    if self.delta_timestamps is not None:
        start_delta = time.perf_counter()
        item, select_cols_time = load_previous_and_future_frames(
            item,
            self.hf_dataset,
            self.episode_data_index,
            self.delta_timestamps,
            tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
        )
        self._delta_time += time.perf_counter() - start_delta
        # select_cols_time only exists on this branch, so accumulate it here
        self._select_cols_time += select_cols_time

    if self.transform is not None:
        item = self.transform(item)
    self._time += time.perf_counter() - start
    return item

You have to make load_previous_and_future_frames return select_cols_time as well:

start = time.perf_counter()
item[key] = hf_dataset.select_columns(key)[data_ids][key]
select_cols_time += time.perf_counter() - start  # select_cols_time initialized to 0.0 at the top of the function

@alexander-soare alexander-soare changed the title Data loading is slow due to dataset.select_columns Data loading is slow due to grabbing frames from dataset Apr 23, 2024
@alexander-soare alexander-soare changed the title Data loading is slow due to grabbing frames from dataset Data loading is slow due to grabbing frames from huggingface dataset Apr 23, 2024
@alexander-soare (Collaborator, Author) commented Apr 23, 2024

FYI, some things I tried:

I tried doing slicing instead of advanced indexing, but it didn't seem to help:

# note: this branch is only equivalent to advanced indexing when data_ids are
# contiguous; uniqueness alone does not guarantee that
if len(torch.unique(data_ids)) == len(data_ids):
    item[key] = hf_dataset.select_columns(key)[data_ids.min().item():data_ids.max().item() + 1][key]
else:
    item[key] = hf_dataset.select_columns(key)[data_ids][key]

I tried using select, and that didn't help either: item[key] = hf_dataset.select_columns(key).select(data_ids)[key]

I tried using the keep_in_memory option when loading the dataset and that didn't help.
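For what it's worth, the slicing fallback is only equivalent to advanced indexing when data_ids are contiguous; uniqueness alone doesn't guarantee that. A toy illustration with a plain list:

```python
# Why the slicing fallback needs contiguous ids, not just unique ids:
# with a gap in data_ids, the slice returns extra rows.
data = list(range(100, 110))          # stand-in for one column of the dataset
data_ids = [2, 5, 7]                  # unique, but not contiguous

fancy = [data[i] for i in data_ids]           # advanced indexing: 3 rows
sliced = data[min(data_ids):max(data_ids)+1]  # slice: 6 rows, including unwanted ones

print(fancy)   # [102, 105, 107]
print(sliced)  # [102, 103, 104, 105, 106, 107]

# the check that actually implies equivalence would be
# max(data_ids) - min(data_ids) + 1 == len(data_ids)
```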

@aliberts aliberts added the ⚡️ Performance Performance-related label Apr 29, 2024
@alexander-soare (Collaborator, Author)

In a more recent benchmark I found that, even with the video datasets, more time is spent accessing the hf_dataset than decoding videos:

from time import perf_counter

from tqdm import trange

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.video_utils import load_from_videos

fps = 15
horizon = 5

delta_timestamps = {
    "observation.image": [i / fps for i in range(horizon + 1)],
    "observation.state": [i / fps for i in range(horizon + 1)],
    "action": [i / fps for i in range(horizon)],
    "next.reward": [i / fps for i in range(horizon)],
}

dataset = LeRobotDataset("lerobot/xarm_lift_medium", delta_timestamps=delta_timestamps)

# cumulative totals across all iterations
get_item_time = 0
delta_timestamps_time = 0
load_video_time = 0
transform_time = 0

def getitem(dataset: LeRobotDataset, idx: int) -> dict:
    global get_item_time, delta_timestamps_time, load_video_time, transform_time

    start = perf_counter()
    item = dataset.hf_dataset[idx]
    print("Get item:", (get_item_time := get_item_time + perf_counter() - start))

    start = perf_counter()
    if dataset.delta_timestamps is not None:
        item = load_previous_and_future_frames(
            item,
            dataset.hf_dataset,
            dataset.episode_data_index,
            dataset.delta_timestamps,
            dataset.tolerance_s,
        )
    print("Delta timestamps:", (delta_timestamps_time := delta_timestamps_time + perf_counter() - start))
    start = perf_counter()
    if dataset.video:
        item = load_from_videos(
            item,
            dataset.video_frame_keys,
            dataset.videos_dir,
            dataset.tolerance_s,
        )
    print("Load video:", (load_video_time := load_video_time + perf_counter() - start))

    start = perf_counter()
    if dataset.transform is not None:
        item = dataset.transform(item)
    print("Transform:", (transform_time := transform_time + perf_counter() - start))

    print()
    return item


for idx in trange(len(dataset)):
    getitem(dataset, idx)

The last iteration prints (cumulative seconds over all iterations):

Get item: 3.056344897253439
Delta timestamps: 88.46708550985204
Load video: 65.91246535727987
Transform: 0.009780120220966637
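Re-expressed as shares of the total (my arithmetic on the printout above):

```python
# Share of total time per stage in the last iteration's cumulative numbers.
times = {
    "get_item": 3.056344897253439,
    "delta_timestamps": 88.46708550985204,
    "load_video": 65.91246535727987,
    "transform": 0.009780120220966637,
}
total = sum(times.values())
for name, t in times.items():
    print(f"{name}: {t / total:.1%}")
```

The delta-timestamps stage (hf_dataset access) dominates at ~56%, ahead of video decoding at ~42%.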


3 participants