Dataset loading and transformation infrastructure for 3D shape completion. Provides a unified interface for loading diverse 3D datasets with configurable data fields and augmentation pipelines.
# As submodule
# Dependencies (from main repo)
uv sync --extra dataset
from dataset import get_dataset, get_transformations
# Using factory (recommended)
datasets = get_dataset(cfg, splits=("train", "val"))
train_dataset = datasets["train"]
val_dataset = datasets["val"]
# With transforms
transforms = get_transformations(cfg, split="train")
dataset/
├── __init__.py # get_dataset(), get_transformations() factories
├── src/
│ ├── __init__.py # Public exports
│ ├── fields.py # Field classes for data loading (15 fields)
│ ├── transforms.py # Transform classes for augmentation (59 transforms)
│ ├── shared.py # SharedDataset, SharedDataLoader (shared-memory data loading)
│ ├── utils.py # Helper functions, TorchvisionDatasetWrapper
│ ├── tv_transforms.py # Torchvision transforms (NormalizeDepth, CenterPad, CameraIntrinsic)
│ │
│ ├── shapenet.py # ShapeNet dataset
│ ├── bop.py # BOP challenge datasets
│ ├── ycb.py # YCB object dataset
│ ├── tabletop.py # Tabletop scenes
│ ├── completion3d.py # Completion3D benchmark
│ ├── modelnet.py # ModelNet dataset
│ ├── graspnet.py # GraspNet dataset
│ ├── coco.py # COCO instance segmentation
│ └── image.py # ImageFolderDataset (from NVlabs/edm)
└── tests/
Dataset.__getitem__(idx)
│
├── Field.load(obj_dir, index, category) # Load raw data (point cloud, mesh, image)
│ └── Caching (optional) # functools.cache decorator with deepcopy
│
├── Transform.__call__(data) # Augmentation pipeline
│ ├── before_apply(data) # Hook before applying to each key
│ ├── apply(data, key) # Core logic, filtered by apply_to
│ └── after_apply(data) # Hook after applying to all keys
│
└── Return dict # {"inputs": ..., "points": ..., "occ": ...}
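The hook order in the diagram above can be sketched as a small template-method class. This is an illustration only: `SketchTransform` and `Double` are hypothetical names, and the real `Transform.__call__` may filter keys differently.

```python
from abc import ABC, abstractmethod
from typing import Optional


class SketchTransform(ABC):
    """Hypothetical stand-in for Transform; not the library class."""

    def __init__(self, apply_to=None):
        # Normalize apply_to into a set of keys, or None for "whole dict".
        if isinstance(apply_to, str):
            apply_to = [apply_to]
        self.apply_to = set(apply_to) if apply_to is not None else None

    def before_apply(self, data: dict) -> dict:
        return data

    @abstractmethod
    def apply(self, data: dict, key: Optional[str]) -> dict:
        ...

    def after_apply(self, data: dict) -> dict:
        return data

    def __call__(self, data: dict) -> dict:
        data = self.before_apply(data)
        if self.apply_to is None:
            # No filter: apply once to the whole dict with key=None.
            data = self.apply(data, None)
        else:
            # Only touch keys that are both requested and present.
            for key in [k for k in data if k in self.apply_to]:
                data = self.apply(data, key)
        return self.after_apply(data)


class Double(SketchTransform):
    def apply(self, data, key):
        if key is not None:
            data[key] = data[key] * 2
        return data


out = Double(apply_to=["inputs", "missing"])({"inputs": 3, "points": 5})
```

Keys listed in `apply_to` but absent from the data dict are skipped silently in this sketch; whether the library warns or raises in that case is not stated here.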
Fields define how to load specific data types from disk. All fields inherit from the Field ABC.
from dataset.src.fields import Field

class Field(ABC):
    def __init__(self, cachable: bool = True, cache: bool = False):
        """
        Args:
            cachable: Whether this field supports caching.
            cache: Enable functools.cache on load() (returns deepcopy of cached result).
        """
        ...

    @abstractmethod
    def load(self, obj_dir: str | Path, index: int, category: int | None) -> dict:
        """Load data from obj_dir. Override in subclasses."""
        raise NotImplementedError
| Field | Description | Output |
|---|---|---|
| PointCloudField | Load point clouds (.npz, .npy) | {None: (N,3), "normals": (N,3)} |
| PointsField | Load query points with occupancy | {None: (M,3), "occ": (M,)} |
| MeshField | Load meshes (.obj, .off, .ply) | {None: trimesh.Trimesh} |
| ImageField | Load images | {None: (H,W,C)} |
| DepthField | Load depth maps with camera params | {None: (H,W), "intrinsic": ..., "extrinsic": ...} |
| RGBDField | Load RGB-D pairs | {"image": (H,W,3), "depth": (H,W)} |
| BlenderProcRGBDField | BlenderProc rendered RGB-D | Multiple views with camera params |
| BOPField | BOP format loader | Scene data dict |
| VoxelsField | Load voxel grids (.binvox) | {None: (D,H,W)} |
| PartNetField | PartNet part annotations | {"points": ..., "labels": ...} |
| DTUField | DTU multi-view dataset loader | Multi-view images with cameras |
| Field | Description |
|---|---|
| RandomField | Return random subset from wrapped field (with optional weights) |
| MixedField | Combine multiple fields (merge or priority-select) |
| EmptyField | Return empty dict |
| IndexField | Return item index, name, and path |
data:
  fields:
    inputs:
      type: PointCloudField
      file_name: pointcloud.npz
      file_key: points
      with_normals: false
    points:
      type: PointsField
      file_name: points.npz
      points_key: points
      occ_key: occupancies
      num_samples: 2048
    mesh:
      type: MeshField
      file_name: mesh.obj
from pathlib import Path

import numpy as np
import torch

from dataset.src.fields import Field

class MyField(Field):
    def __init__(self, file_name: str, cache: bool = False):
        super().__init__(cachable=True, cache=cache)
        self.file_name = file_name

    def load(self, obj_dir: str | Path, index: int, category: int | None = None) -> dict:
        # Assumes file_name points to a .npy array; an .npz archive would need a key lookup.
        file_path = Path(obj_dir) / self.file_name
        data = np.load(file_path)
        return {None: torch.from_numpy(data)}
Transforms augment and process data. All transforms inherit from the Transform ABC. There are 59 concrete transforms organized into categories below.
from dataset.src.transforms import Transform

class Transform(ABC):
    def __init__(
        self,
        apply_to: str | list[str] | tuple[str, ...] | None = None,
        allowed: str | list[str] | tuple[str, ...] | None = None,
        cachable: bool = False,
    ):
        """
        Args:
            apply_to: Key(s) to apply transform to. None = apply to all keys.
                Automatically converted to a set for filtering.
            allowed: Valid keys for apply_to. Defaults to standard set:
                {"inputs", "inputs.depth", "inputs.normals", "inputs.image",
                 "points", "pointcloud", "pointcloud.normals",
                 "mesh.vertices", "mesh.normals", "voxels", "bbox",
                 "partnet.points"}
            cachable: Whether this transform's output can be cached.
        """

    def before_apply(self, data: dict) -> dict:
        """Hook before applying to each key."""
        return data

    @abstractmethod
    def apply(self, data: dict, key: str | None) -> dict:
        """Override to implement transform logic."""
        raise NotImplementedError

    def after_apply(self, data: dict) -> dict:
        """Hook after applying to all keys."""
        return data
| Transform | Description | Key Params |
|---|---|---|
| Rotate | Random rotation | axes, angles, from_inputs |
| Affine | Affine transformation (from extrinsics) | replace |
| Translate | Translation | offset range |
| Scale | Uniform/non-uniform scaling | scale range |
| Normalize | Center and scale to unit sphere | center, scale, to_front, reference, scale_method |
| ApplyPose | Apply rigid pose transform | |
| RefinePose | Refine pose via ICP | ICP params |
| RefinePosePerInstance | Per-instance pose refinement | ICP params |
| Transform | Description | Key Params |
|---|---|---|
| SubsamplePointcloud | Random subsampling | num_samples |
| SubsamplePoints | Subsample query points | num_samples, in_out_ratio |
| AddGaussianNoise | Add Gaussian noise | std |
| CropPointcloud | Crop to bounding box | bounds |
| CropPointcloudWithMesh | Crop point cloud using mesh | |
| CropPoints | Crop query points | padding |
| AxesCutPointcloud | Planar cut | axes, cut_ratio, rotate_object |
| SphereCutPointcloud | Spherical cut | radius |
| SphereMovePointcloud | Move points along sphere surface | |
| ProcessPointcloud | Downsample + outlier removal | downsample, remove_outlier |
| RemoveHiddenPointsFromInputs | Hidden point removal | viewpoint |
| DepthLikePointcloud | Simulate depth-sensor partial view | rotate_object, upper_hemisphere |
| RotatePointcloud | Rotate point cloud specifically | axes, angles |
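To make the two most common point cloud operations concrete, here is a minimal NumPy sketch of random subsampling followed by Gaussian noise. The function names are hypothetical, and sampling with replacement for undersized clouds is an assumption rather than documented behavior of SubsamplePointcloud.

```python
import numpy as np


def subsample_pointcloud(points, num_samples, rng):
    # Sample with replacement only when the cloud has fewer points than requested.
    replace = points.shape[0] < num_samples
    indices = rng.choice(points.shape[0], size=num_samples, replace=replace)
    return points[indices]


def add_gaussian_noise(points, std, rng):
    # Independent zero-mean noise on every coordinate.
    return points + rng.normal(0.0, std, size=points.shape)


rng = np.random.default_rng(0)
cloud = rng.uniform(-0.5, 0.5, size=(5000, 3))
inputs = add_gaussian_noise(subsample_pointcloud(cloud, 2048, rng), 0.01, rng)
```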
| Transform | Description | Key Params |
|---|---|---|
| Render | Render mesh to images (pyrender) | width, height |
| RenderPointcloud | Render point cloud | resolution |
| RenderDepthMaps | Multi-view depth rendering | num_views |
| DepthToPointcloud | Unproject depth to 3D | intrinsics |
| ShadingImageFromNormals | Generate shading from normals | |
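Unprojecting depth to 3D, as DepthToPointcloud does, usually follows the pinhole camera model. The sketch below shows that standard computation with NumPy; the function name and the convention that zero depth marks invalid pixels are assumptions, not the library's signature.

```python
import numpy as np


def depth_to_pointcloud(depth, fx, fy, cx, cy):
    """Back-project an (H, W) depth map into (N, 3) camera-frame points."""
    height, width = depth.shape
    # u indexes columns (x direction), v indexes rows (y direction).
    u, v = np.meshgrid(np.arange(width), np.arange(height))
    valid = depth > 0  # assumption: zero depth means "no measurement"
    z = depth[valid]
    x = (u[valid] - cx) * z / fx
    y = (v[valid] - cy) * z / fy
    return np.stack([x, y, z], axis=1)


depth = np.zeros((4, 4))
depth[2, 3] = 2.0  # a single valid pixel at row 2, column 3
points = depth_to_pointcloud(depth, fx=2.0, fy=2.0, cx=1.0, cy=1.0)
```

For the single valid pixel, x = (3 - 1) * 2 / 2 = 2 and y = (2 - 1) * 2 / 2 = 1, so the result is one point at (2, 1, 2).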
| Transform | Description |
|---|---|
| NormalizeMesh | Normalize mesh to unit cube |
| RotateMesh | Rotate mesh |
| PointcloudFromMesh | Sample surface points from mesh |
| PointsFromMesh | Sample query points (with occupancy) from mesh |
| PointsFromPointcloud | Sample query points from point cloud |
| Transform | Description |
|---|---|
| EdgeNoise | Add noise at depth discontinuities |
| ImageBorderNoise | Add noise at image borders |
| AngleOfIncidenceRemoval | Remove points by viewing angle |
| Transform | Description |
|---|---|
| ImageToTensor | Convert images to normalized tensors (with optional resize/crop) |
| Torchvision | Wrap any torchvision transform |
| Transform | Description |
|---|---|
| VoxelizePointcloud | Voxelize point cloud |
| VoxelizePoints | Voxelize query points |
| BPS | Basis Point Set encoding |
| BoundingBox | Compute bounding box from reference |
| BoundingBoxJitter | Jitter bounding box |
| Transform | Description |
|---|---|
| SdfFromOcc | Convert occupancy to SDF/TSDF |
| SegmentationFromPartNet | Generate segmentation labels from PartNet |
| NormalsCameraCosineSimilarity | Compute normal-camera cosine similarity |
| InputsNormalsFromPointcloud | Extract normals for inputs from point cloud |
| Permute | Permute tensor dimensions |
| MinMaxNumPoints | Enforce min/max point counts (pad or subsample) |
| LoadUncertain | Load uncertainty data |
| FindUncertainPoints | Identify uncertain query points |
| SplitData | Split data dict into sub-dicts |
| Compress | Compress arrays (float16, packbits) |
| Unpack | Unpack compressed arrays |
| CheckDtype | Validate/convert tensor dtypes |
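The Compress and Unpack rows mention float16 and packbits. As a rough illustration of what that pairing buys, the NumPy sketch below packs boolean occupancies into bits and halves coordinate storage; it shows the general technique, not the transforms' actual implementation.

```python
import numpy as np

# Occupancy labels: one bool per query point, one byte each before packing.
occ = np.array([1, 0, 0, 1, 1, 0, 1, 0, 1, 1], dtype=bool)
packed = np.packbits(occ)  # 10 bits are padded up to 2 bytes
restored = np.unpackbits(packed, count=occ.size).astype(bool)

# Coordinates: float32 -> float16 halves memory at the cost of precision.
points = np.random.default_rng(0).uniform(-0.5, 0.5, size=(10, 3)).astype(np.float32)
half = points.astype(np.float16)
max_error = np.abs(half.astype(np.float32) - points).max()
```

Bit packing is lossless as long as the original length is kept for unpacking, while the float16 cast is lossy, so it suits normalized coordinates better than large-range values.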
| Transform | Description |
|---|---|
| Return | Return specific keys only |
| RandomChoice | Randomly select one transform from a list |
| RandomApply | Randomly apply a transform with probability |
| KeysToKeep | Filter output to specified keys |
| Transform | Description |
|---|---|
| Visualize | Debug visualization (plotly) |
| SaveData | Save intermediate results to disk |
from dataset.src.transforms import (
    SubsamplePointcloud,
    AddGaussianNoise,
    Rotate,
    Normalize,
)

transforms = [
    # Subsample input point cloud to 2048 points
    SubsamplePointcloud(apply_to="inputs", num_samples=2048),
    # Add noise only during training
    AddGaussianNoise(apply_to="inputs", std=0.01),
    # Random rotation around Z axis
    Rotate(apply_to=["inputs", "points"], axes="z", angles=360),
    # Normalize to unit sphere
    Normalize(apply_to=["inputs", "points", "mesh.vertices"]),
]
apply_transforms
from dataset.src.transforms import apply_transforms
# Apply a list of transforms with per-transform timing logs
data = apply_transforms(data, transforms)
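A helper that applies transforms in order and logs per-transform timing could look roughly like the sketch below. `apply_transforms_sketch` is a hypothetical stand-in; the logger name, log level, and message format used by the real `apply_transforms` are not documented here.

```python
import logging
import time

logger = logging.getLogger(__name__)


def apply_transforms_sketch(data, transforms):
    """Run each callable on the data dict in order, logging its duration."""
    for transform in transforms:
        start = time.perf_counter()
        data = transform(data)
        elapsed_ms = (time.perf_counter() - start) * 1e3
        logger.debug("%s took %.2f ms", type(transform).__name__, elapsed_ms)
    return data


# Plain functions stand in for Transform instances in this example.
result = apply_transforms_sketch(
    {"x": 1},
    [lambda d: {**d, "x": d["x"] + 1}, lambda d: {**d, "y": d["x"] * 10}],
)
```

Because each transform receives the output of the previous one, order matters: the second callable here sees the already incremented `x`.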
The get_dataset() factory routes by cfg.data.train_ds / cfg.data.val_ds / cfg.data.test_ds names. Supported dataset identifiers:
| Identifier | Class | Notes |
|---|---|---|
| shapenet* | ShapeNet | Any name containing “shapenet” (e.g., shapenet_v1, shapenet_v2) |
| completion3d | Completion3D | Stanford Completion3D benchmark |
| ycb | YCB | YCB object dataset (train/val/test + real-data mode) |
| modelnet40 | ModelNet | ModelNet40 classification dataset |
| mnist, fmnist, cifar10 | torchvision wrappers | Image classification datasets |
| coco | CocoInstanceSegmentation | COCO instance segmentation |
| tabletop* | TableTop | Any name containing “tabletop” |
| graspnet* | GraspNetEval | Any name starting with “graspnet” |
| bop_* | BOP | BOP challenge (test only): bop_ycbv, bop_lm, bop_hb, bop_tyol |
| Other | ShapeNet | Falls back to ShapeNet-style loading with custom cfg.dirs[ds] |
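The routing in the table can be pictured as a chain of string checks. The sketch below returns class names as strings so it runs without the package; `resolve_dataset_class` is hypothetical and the order of checks is an assumption, since only the matching rules themselves are documented.

```python
EXACT_MATCHES = {
    "completion3d": "Completion3D",
    "ycb": "YCB",
    "modelnet40": "ModelNet",
    "coco": "CocoInstanceSegmentation",
}
TORCHVISION_NAMES = {"mnist", "fmnist", "cifar10"}


def resolve_dataset_class(ds: str) -> str:
    """Map a dataset identifier to the name of the class that would load it."""
    if "shapenet" in ds:
        return "ShapeNet"
    if ds in EXACT_MATCHES:
        return EXACT_MATCHES[ds]
    if ds in TORCHVISION_NAMES:
        return "TorchvisionDatasetWrapper"
    if "tabletop" in ds:
        return "TableTop"
    if ds.startswith("graspnet"):
        return "GraspNetEval"
    if "bop" in ds:
        return "BOP"
    # Unknown names fall back to ShapeNet-style loading from cfg.dirs[ds].
    return "ShapeNet"


resolved = {name: resolve_dataset_class(name) for name in ("shapenet_v2", "bop_ycbv", "my_scans")}
```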
Additional dataset classes available for direct use:
- ImageFolderDataset: zip/folder image dataset (from NVlabs/edm)
- SharedDataset / SharedDataLoader: shared-memory wrappers for distributed training

from dataset.src.shapenet import ShapeNet

dataset = ShapeNet(
    root="/path/to/shapenet",
    split="train",  # train | val | test
    categories=["chair", "table"],  # or None for all
    fields={"inputs": field, "points": field},
    transforms=transforms,
)
Categories: 57 ShapeNet categories supported (see CATEGORIES_MAP in shapenet.py)
from dataset.src.bop import BOP

dataset = BOP(
    root="/path/to/bop",
    dataset="ycbv",  # ycbv | lm | tless | itodd | ...
    split="train_pbr",
    fields=fields,
)

from dataset.src.ycb import YCB

dataset = YCB(
    root="/path/to/ycb",
    split="train",
    objects=["002_master_chef_can", "003_cracker_box"],
)
Custom tabletop scene dataset with rendered views.
from dataset.src.tabletop import TableTop

dataset = TableTop(
    root="/path/to/tabletop",
    split="train",
    scene_type="single",  # single | multi | clutter
)
Stanford Completion3D benchmark dataset.
from dataset.src.completion3d import Completion3D

dataset = Completion3D(
    root="/path/to/completion3d",
    split="train",
)
from dataset.src.modelnet import ModelNet

dataset = ModelNet(
    root="/path/to/modelnet",
    version=40,  # 10 | 40
    split="train",
)

from dataset.src.graspnet import GraspNetEval

dataset = GraspNetEval(
    root="/path/to/graspnet",
    split="test",
)
All configuration uses Hydra. Config files live in conf/ at the main repo root.
data:
  train_ds: shapenet_v1  # Dataset identifier(s); can be a list for multi-dataset
  val_ds: null  # Defaults to train_ds if null
  test_ds: null  # Defaults to val_ds if null
  categories: [chair, table]  # Categories to load (null = all)
  cache: false  # Cache loaded data
  hash_items: false  # Use hashed item paths
  sdf_from_occ: false  # Convert occupancy to SDF
  dither: false  # Dither float32 tensors during training
inputs:
  type: pointcloud  # pointcloud | depth | image | rgbd | partial | depth_like |
                    # kinect | shading | normals | render
  dim: 3  # Point dimension (3=xyz, 6=xyz+normals)
  num_points: 2048  # Points to load
  project: false  # Project depth to point cloud
  cache: false  # Cache rendered inputs
  load_random: true  # Random view selection
  voxelize: 0  # Voxelize inputs (0 = disabled, else resolution)
  permute: false  # Permute tensor dimensions
  bbox: false  # Compute bounding box
  min_num_points: 0  # Minimum points (pad if fewer)
  max_num_points: 0  # Maximum points (subsample if more)
  # Image inputs
  width: 640
  height: 480
  resize: null  # Resize dimensions
  crop: 0  # Center crop size
  normals: false  # Load normal maps
  # BPS encoding
  bps:
    num_points: 0
    resolution: 0
    method: null
    feature: null
    basis: null
  # FPS sampling
  fps:
    num_points: 0
pointcloud:
  from_mesh: false  # Sample from mesh instead of loading
  normals: false  # Load normals
  bbox: false  # Compute bounding box
  train:
    num_points: 100000  # Surface points for training
  val:
    num_points: 100000
points:
  dim: 3  # Query point dimension
  from_mesh: false  # Sample from mesh
  from_pointcloud: false  # Sample from point cloud
  subsample: true  # Enable subsampling
  crop: false  # Crop to bounds
  voxelize: 0  # Voxelize (0 = disabled)
  cache: false  # Cache
  bbox: false  # Compute bounding box
  min_num_points: 0  # Minimum points
  train:
    num_samples: 2048  # Query points per sample
    ratio: 0.5  # Surface vs volume ratio
  val:
    num_samples: 100000
aug:
  rotate: z  # Rotation axes (x|y|z|xy|xyz|cam|none)
  scale: [0.9, 1.1]  # Scale range
  translate: 0.1  # Translation range
  noise: 0.01  # Gaussian noise std
  # Point cloud specific
  downsample: false  # Enable downsampling
  remove_hidden: false  # Remove occluded points
  upper_hemisphere: true  # Camera in upper hemisphere only
  remove_outlier: false  # Statistical outlier removal
  move_sphere: false  # Sphere-based point movement
  bbox_jitter: 0  # Bounding box jitter amount
  # Depth specific
  edge_noise: false  # Add edge artifacts
  remove_angle: false  # Remove by incidence angle
  border_noise: false  # Add border noise to images
  train:
    no_aug: false  # Disable augmentation
  val:
    no_aug: true
  test:
    no_aug: true
norm:
  center: ""  # Center axes (e.g., "xyz")
  scale: false  # Scale to unit sphere
  to_front: false  # Rotate to front
  offset: null  # Translation offset
  true_height: false  # Use true height for normalization
  reference: null  # Reference for normalization (null | "mesh" | "pointcloud")
  method: null  # Scale method
  padding: 0.1  # Padding for bounding box

mesh:
  norm: true  # Normalize mesh
  rot: null  # Pre-rotation angles [x, y, z]
  bbox: false  # Compute bounding box
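The comments on val_ds and test_ds in the data block describe a fallback chain, and train_ds may be a single name or a list. One way to resolve those rules is sketched below; `resolve_split_datasets` is a hypothetical helper, not a function of this package.

```python
def resolve_split_datasets(train_ds, val_ds=None, test_ds=None):
    """Return the dataset identifiers used for each split as lists."""

    def as_list(value):
        # Accept either a single identifier or a list of identifiers.
        return [value] if isinstance(value, str) else list(value)

    train = as_list(train_ds)
    val = as_list(val_ds) if val_ds is not None else train  # null -> train_ds
    test = as_list(test_ds) if test_ds is not None else val  # null -> val_ds
    return {"train": train, "val": val, "test": test}


splits = resolve_split_datasets("shapenet_v1", val_ds=None, test_ds="bop_ycbv")
```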
Create dataset/src/mydataset.py:
from pathlib import Path

from torch.utils.data import Dataset

from .transforms import Transform, apply_transforms

class MyDataset(Dataset):
    def __init__(
        self,
        root: str,
        split: str = "train",
        fields: dict | None = None,
        transforms: list[Transform] | None = None,
        categories: list | None = None,
    ):
        self.root = Path(root)
        self.split = split
        self.fields = fields or {}
        self.transforms = transforms or []
        # Build item list
        self.items = self._load_split()

    def _load_split(self):
        """Load split file or scan directory."""
        split_file = self.root / f"{self.split}.lst"
        if split_file.exists():
            return split_file.read_text().strip().split("\n")
        # Return names relative to root so __getitem__ can join them with self.root
        return sorted(p.name for p in self.root.glob("*"))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        path = self.root / item
        # Load fields
        data = {}
        for name, field in self.fields.items():
            data[name] = field.load(str(path), idx, category=None)
        # Apply transforms (with timing logs)
        data = apply_transforms(data, self.transforms)
        return data
Export it from dataset/src/__init__.py:
from .mydataset import MyDataset
Register it in get_dataset() in dataset/__init__.py. The factory routes by matching cfg.data.train_ds string patterns:
# In the train/val/test loops inside get_dataset():
elif ds == "mydataset":
    data = MyDataset(
        root=cfg.dirs[ds],
        split=split,
        fields=fields,
        transforms=get_transformations(cfg, split),
    )
Pattern matching rules in get_dataset():
- "shapenet" in ds: any name containing “shapenet”
- "tabletop" in ds: any name containing “tabletop”
- ds.startswith("graspnet"): any name starting with “graspnet”
- "bop" in ds: BOP datasets (test split only)
- completion3d, ycb, modelnet40, mnist, fmnist, cifar10, coco: matched by name
- Anything else: ShapeNet-style loading from cfg.dirs[ds]

Create conf/mydataset.yaml:
defaults:
  - config
  - _self_

data:
  train_ds: mydataset
  val_ds: mydataset
Add the data root to conf/dirs/default.yaml:
mydataset: /path/to/mydataset
Add to dataset/src/transforms.py:
class MyTransform(Transform):
    @get_args()  # Captures constructor args for serialization/logging
    def __init__(
        self,
        my_param: float = 1.0,
        apply_to: str | list[str] | None = None,
        cachable: bool = False,
    ):
        super().__init__(apply_to=apply_to, cachable=cachable)
        self.my_param = my_param

    def apply(self, data: DataDict, key: str | None) -> DataDict:
        # key is None when apply_to=None (applies to whole dict)
        # key is a string when apply_to is set (e.g., "inputs", "points")
        if key is not None:
            data[key] = data[key] * self.my_param
        return data
At the bottom of transforms.py, add "MyTransform" to the __all__ list. This auto-exports it through dataset/src/__init__.py.
In dataset/__init__.py, add it to the explicit import block:
from .src.transforms import MyTransform
Optionally, if the transform should be automatically included based on config flags, add it to the get_transformations() function in dataset/__init__.py:
if cfg.aug.my_flag:
    transformations.append(MyTransform(apply_to="inputs", my_param=cfg.aug.my_param))
Fields support automatic caching via the cache constructor parameter. The cache stores results with functools.cache and returns deep copies to prevent mutation:
field = PointCloudField(file_name="pointcloud.npz", cache=True)
# First call loads from disk, subsequent calls return deepcopy of cached result
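The cache-then-deepcopy behavior described above can be sketched with the standard library alone. `CachedLoader` is a hypothetical class used only for illustration; how the real Field wires `functools.cache` into `load()` may differ.

```python
import copy
import functools


class CachedLoader:
    """Hypothetical loader; only illustrates cache-then-deepcopy."""

    def __init__(self):
        self.disk_reads = 0
        # Wrap per instance so each loader keeps its own cache.
        self._cached_read = functools.cache(self._read)

    def _read(self, path: str) -> dict:
        self.disk_reads += 1
        return {"points": [[0.0, 0.0, 0.0]]}  # stand-in for real file I/O

    def load(self, path: str) -> dict:
        # Copy so callers can mutate the result without touching the cache.
        return copy.deepcopy(self._cached_read(path))


loader = CachedLoader()
first = loader.load("a.npz")
first["points"].append([1.0, 1.0, 1.0])  # in-place edit by a transform
second = loader.load("a.npz")
```

Without the deep copy, the in-place edit to `first` would leak into every later `load()` of the same item, which is the mutation problem the copy guards against. The trade-off is one copy per access.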
Some transforms support the cachable flag, which signals to the SharedDataset infrastructure that their output can be stored in shared memory:
transform = VoxelizePointcloud(apply_to="inputs", resolution=32, cachable=True)
SharedDataset and SharedDataLoader enable caching dataset items in shared memory for distributed training:
from dataset.src.shared import SharedDataset, SharedDataLoader
shared_ds = SharedDataset(dataset)
loader = SharedDataLoader(shared_ds, batch_size=32)
Configure via:
load:
  share_memory: true

load:
  weighted: true

data:
  cache: true

points:
  train:
    num_samples: 2048
  val:
    num_samples: 100000  # More for accurate eval

load:
  share_memory: true

inputs:
  voxelize: 32  # Produces (32, 32, 32) voxel grid

load:
  keys_to_keep: [inputs, points, occ]