Skip to content

utils

utils

Modules:

Name Description
articulation_utils
asset_names
benchmark_utils

Benchmark utilities for kinematics outlier detection in trajectory H5 files.

camera_utils
constants
controller_utils
depth_utils

Utilities for depth image encoding and decoding.

devices
distance_transform_utils
eval_camera_randomization_utils

Level → value scaling for camera and light randomization.

eval_utils

Evaluation utilities for logging stats and videos to wandb.

fisheye_warping

GPU-accelerated fisheye lens distortion warping for camera images.

function_utils
grasp_sample

This module contains functionality for filtering and sampling grasps based on heuristics.

grasps

This module contains functionality for loading grasps from registered grasp libraries.

lazy_loading_utils
lemma_utils
license_utils
linalg_utils
mj_model_and_data_utils
mp_logging
mujoco_scene_utils
object_metadata
object_retriever
patch_renderer_flags

Import this module to configure the renderer flags for the current platform.

pose
profiler_utils
rendering_utils
sampler_utils
save_utils
scene_maps
scene_metadata_utils
spatial_utils

Quaternions are assumed to be scalar first!

synset_utils
task_relevant_objects_and_workspace_utils

Derive task-relevant object names and workspace center from task config fields.

test_utils

Shared utilities for data generation tests (Franka, RUM, etc.).

video_utils

Copied from video2sim_pipeline/video2sim/utils/video_utils.py

articulation_utils

Functions:

Name Description
gather_joint_info
step_circular_path

Generate gripper mocap waypoints that follow a revolute joint's arc.

step_linear_path
visualize_path

Comprehensive visualization of the gripper base path and finger center arc.

Attributes:

Name Type Description
GRIPPER_LENGTH

GRIPPER_LENGTH module-attribute

GRIPPER_LENGTH = 0.125

gather_joint_info

gather_joint_info(model, data, joint_name_or_index)
Source code in molmo_spaces/utils/articulation_utils.py
def gather_joint_info(model, data, joint_name_or_index):
    """Collect a snapshot of a joint's model properties and current state.

    Args:
        model: Simulation model (presumably ``mujoco.MjModel``);
            ``model.joint(joint_name_or_index)`` must resolve to the joint.
        data: Simulation data matching ``model``; its ``qpos``, ``xpos`` and
            ``xmat`` are read at the joint's qpos address / body id.
        joint_name_or_index: Joint name or index accepted by ``model.joint``.

    Returns:
        dict with keys:
            joint_axis, joint_position: the joint's ``axis`` / ``pos`` from the
                model (``step_circular_path`` treats them as expressed in the
                joint body's frame).
            joint_range: the joint's ``range`` (lower, upper).
            joint_pos: current value of the joint's first qpos entry.
            joint_body_position: world position of the joint's body.
            joint_body_orientation: 3x3 world orientation of the joint's body.
            max_range: ``joint_range[1]`` unless it is exactly 0, otherwise
                ``joint_range[0]``.
            joint_type, joint_qpos_adr, joint_id: identifiers from the model.

        All array values are copies, so the snapshot does not change when the
        simulation is stepped (or the model edited) afterwards.
    """
    joint = model.joint(joint_name_or_index)  # resolve the joint once
    body_id = joint.bodyid[0]
    qpos_adr = joint.qposadr[0]

    # np.array(...) copies: model/data accessors can return views into live buffers.
    joint_range = np.array(joint.range)
    # Target the upper limit unless it is exactly 0, then fall back to the lower limit.
    max_range = joint_range[1] if joint_range[1] != 0 else joint_range[0]

    return {
        "joint_axis": np.array(joint.axis),
        "joint_position": np.array(joint.pos),
        "joint_range": joint_range,
        "joint_pos": data.qpos[qpos_adr],
        "joint_body_position": np.array(data.xpos[body_id]),
        "joint_body_orientation": np.array(data.xmat[body_id]).reshape(3, 3),
        "max_range": max_range,
        "joint_type": joint.type,
        "joint_qpos_adr": qpos_adr,
        "joint_id": joint.id,
    }

step_circular_path

step_circular_path(current_pos, current_quat, joint_info, max_joint_angle, n_waypoints=10, gripper_length=0)
joint_info keys used: joint_body_position, joint_axis, joint_body_orientation, joint_position, joint_range, joint_pos.

Source code in molmo_spaces/utils/articulation_utils.py
def step_circular_path(
    current_pos,
    current_quat,
    joint_info,
    max_joint_angle,
    n_waypoints=10,
    gripper_length=0,
):
    """
    joint_info:
        joint_body_position
        joint_axis
        joint_body_orientation
        joint_position
        joint_range
        joint_pos
    """

    def rotation_matrix_from_axis_angle(axis, angle):
        """Create rotation matrix from axis and angle using scipy's reliable implementation"""
        axis = axis / np.linalg.norm(axis)
        # Use scipy's implementation which is more reliable

        R_matrix = R.from_rotvec(axis * angle).as_matrix()
        return R_matrix

    ## extract joint info
    joint_body_position = joint_info["joint_body_position"]
    joint_axis_local = joint_info["joint_axis"]

    # Convert joint axis from local body frame to global frame
    # Get the body's orientation matrix
    body_orientation = joint_info["joint_body_orientation"]
    joint_axis = body_orientation @ joint_axis_local

    # joint position is in the joint frame, so we need to convert it to the world frame
    # joint_position = joint_info["joint_position"] + need to convert to world frame by multiplying body orientation
    joint_position = body_orientation @ joint_info["joint_position"] + joint_body_position

    # Use gripper position to find the arc that gripper follows
    handle_position = current_pos
    handle_orientation = current_quat

    # get offset from joint to gripper
    handle_offset = handle_position - joint_position

    if np.abs(joint_axis[2]) > 0.9:
        # if rotating along the global z axis, make the height same
        joint_position[2] = handle_position[2]

        # For Z-axis rotation, we need to ensure the rotation is in the XY plane
        # The gripper offset should be in the XY plane only
        handle_offset_xy = handle_offset.copy()
        handle_offset_xy[2] = 0  # Zero out Z component for XY plane rotation
        handle_offset = handle_offset_xy
    current_joint_angle = joint_info["joint_pos"]

    # Calculate relative angles (change from current to max)
    angle_change = max_joint_angle - current_joint_angle
    angles = np.linspace(0, angle_change, n_waypoints + 1)
    if np.abs(angle_change) < 0.1:
        angles = np.linspace(0, -max_joint_angle, n_waypoints + 1)

    # Get gripper orientation matrix
    gripper_orientation_matrix = R.from_quat(handle_orientation, scalar_first=True).as_matrix()

    # Calculate the finger center offset from joint (this follows the circular arc)
    # The finger center should be at the handle position initially
    finger_center_offset_from_joint = handle_position - joint_position

    if np.abs(joint_axis[2]) > 0.9:
        # For Z-axis rotation, ensure the finger center offset is in XY plane only
        finger_center_offset_from_joint_xy = finger_center_offset_from_joint.copy()
        finger_center_offset_from_joint_xy[2] = 0
        finger_center_offset_from_joint = finger_center_offset_from_joint_xy

    path = {"mocap_pos": [], "mocap_quat": []}

    for angle in angles:
        # Calculate rotation matrix for this angle
        R_matrix = rotation_matrix_from_axis_angle(joint_axis, angle)

        # Rotate finger center offset around the joint (finger center follows circular arc)
        rotated_finger_center_offset = R_matrix @ finger_center_offset_from_joint

        # Calculate new finger center position (this follows the circular arc)
        next_finger_center_pos = joint_position + rotated_finger_center_offset

        # Calculate new gripper orientation by applying the joint rotation to the original orientation
        next_gripper_orientation_matrix = R_matrix @ gripper_orientation_matrix
        next_gripper_z_axis = next_gripper_orientation_matrix[:, 2]  # Z-axis points towards handle

        # Calculate gripper base position by offsetting from finger center
        # Gripper base is offset by half gripper length in the negative Z direction (away from handle)
        gripper_base_offset_from_finger = -0.5 * gripper_length * next_gripper_z_axis
        next_gripper_base_pos = next_finger_center_pos + gripper_base_offset_from_finger

        # Convert to quaternion
        next_quat = R.from_matrix(next_gripper_orientation_matrix).as_quat(scalar_first=True)

        # Use gripper base position for the mocap trajectory (but finger center follows the arc)
        path["mocap_pos"].append(next_gripper_base_pos)
        path["mocap_quat"].append(next_quat)

    visualize = False
    if visualize:
        visualize_path(path, joint_position=joint_position)

    return path

step_linear_path

step_linear_path(to_handle_dist, current_pos, current_quat, step_size, is_reverse=False, gripper_length=0)
Source code in molmo_spaces/utils/articulation_utils.py
def step_linear_path(
    to_handle_dist,
    current_pos,
    current_quat,
    step_size,
    is_reverse=False,
    gripper_length=0,
):
    """Generate fixed-orientation mocap waypoints along a horizontal straight line.

    The heading is the XY projection of ``to_handle_dist``; the path moves
    toward it (or away from it when ``is_reverse``) in steps of ``step_size``.

    Args:
        to_handle_dist: Vector from the gripper to the handle, shape (3,).
        current_pos: Starting position (included as the first waypoint).
        current_quat: Orientation used for every waypoint (same object reused).
        step_size: Distance moved per waypoint.
        is_reverse: If True, move away from the handle, and extend the travel
            distance by ``gripper_length`` instead of shortening it.
        gripper_length: Subtracted from (forward) or added to (reverse) the
            travel distance.

    Returns:
        dict with ``"mocap_pos"`` and ``"mocap_quat"`` lists containing the start
        pose followed by ``int(travel / step_size)`` steps (none if that is <= 0).

    NOTE(review): the step count uses the full 3D norm of ``to_handle_dist`` while
    motion is confined to the XY plane, so a vertical component lengthens the
    horizontal travel. If the XY component is zero, ``arctan2(0, 0) == 0`` and the
    path moves along +x. Confirm both are intended.
    """
    path = {"mocap_pos": [current_pos], "mocap_quat": [current_quat]}

    travel = np.linalg.norm(to_handle_dist)
    travel = travel + gripper_length if is_reverse else travel - gripper_length

    n_steps = int(travel / step_size)
    if n_steps <= 0:
        return path

    # The heading does not change along the path, so compute the step vector once.
    heading = np.arctan2(to_handle_dist[1], to_handle_dist[0])
    step = step_size * np.array([np.cos(heading), np.sin(heading), 0])
    if is_reverse:
        step = -step

    for _ in range(n_steps):
        current_pos = current_pos + step
        path["mocap_pos"].append(current_pos)
        path["mocap_quat"].append(current_quat)
    return path

visualize_path

visualize_path(path, title='Gripper Base Path Visualization', save_path=None, joint_position=None, show_finger_center=True)

Comprehensive visualization of the gripper base path and finger center arc.

Parameters:

Name Type Description Default
path

Dictionary with 'mocap_pos' and 'mocap_quat' lists (representing gripper base positions)

required
title

Title for the plot

'Gripper Base Path Visualization'
save_path

Optional path to save the plot

None
joint_position

Optional joint position to visualize

None
show_finger_center

If True, also show the finger center arc for reference

True
Source code in molmo_spaces/utils/articulation_utils.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
def _plot_path_2d(ax, positions, dims, fmt, path_label, start_label, end_label):
    """Plot ``positions`` projected on world axes ``dims`` with start/end markers."""
    a, b = dims
    ax.plot(positions[:, a], positions[:, b], fmt, linewidth=2, markersize=4, label=path_label)
    ax.scatter(positions[0, a], positions[0, b], color="green", s=100, label=start_label, zorder=5)
    ax.scatter(positions[-1, a], positions[-1, b], color="red", s=100, label=end_label, zorder=5)


def _plot_joint_2d(ax, joint_position, dims, link_target=None, link_label=None):
    """Mark the joint with a star; optionally draw a dashed line to ``link_target``."""
    a, b = dims
    ax.scatter(
        joint_position[a],
        joint_position[b],
        color="purple",
        s=200,
        marker="*",
        label="Joint",
        zorder=6,
    )
    if link_target is not None:
        ax.plot(
            [joint_position[a], link_target[a]],
            [joint_position[b], link_target[b]],
            "k--",
            alpha=0.5,
            linewidth=1,
            label=link_label,
        )


def _finish_2d_axes(ax, xlabel, ylabel, title):
    """Apply axis labels, title, legend, grid and equal aspect to a 2D subplot."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal")


def _format_xyz_range(points):
    """Format the per-axis min/max of an (N, 3) array as 'X[a, b], Y[c, d], Z[e, f]'."""
    return ", ".join(
        f"{name}[{points[:, k].min():.3f}, {points[:, k].max():.3f}]"
        for k, name in enumerate("XYZ")
    )


def visualize_path(
    path,
    title="Gripper Base Path Visualization",
    save_path=None,
    joint_position=None,
    show_finger_center=True,
    gripper_length=None,
):
    """
    Comprehensive visualization of the gripper base path and finger center arc.

    Draws top-down (X-Y), side (X-Z), front (Y-Z) and 3D views, shows the
    figure, and prints path statistics.

    Args:
        path: Dictionary with 'mocap_pos' and 'mocap_quat' lists (representing
            gripper base positions and scalar-first quaternions)
        title: Title for the plot
        save_path: Optional path to save the plot
        joint_position: Optional joint position to visualize
        show_finger_center: If True, also show the finger center arc for reference
        gripper_length: Gripper length used when the path was generated; the
            finger center is reconstructed half of it ahead of each base position
            along the gripper z-axis. Defaults to the module's GRIPPER_LENGTH.

    Returns:
        The matplotlib Figure, or None when ``path`` has no positions.
    """
    from matplotlib.patches import FancyArrowPatch

    if not path or "mocap_pos" not in path or len(path["mocap_pos"]) == 0:
        print("No path data to visualize")
        return None

    if gripper_length is None:
        gripper_length = GRIPPER_LENGTH

    # Waypoints may be lists/tuples or arrays; stack them into (N, 3) / (N, 4).
    positions = np.array([np.asarray(pos) for pos in path["mocap_pos"]])
    raw_quats = path.get("mocap_quat")
    quaternions = (
        np.array([np.asarray(quat) for quat in raw_quats])
        if raw_quats is not None and len(raw_quats) > 0
        else None
    )

    print(f"Visualizing gripper base path with {len(positions)} points")
    print(f"Base position range: {_format_xyz_range(positions)}")

    # Finger center = base + 0.5 * gripper_length along each pose's gripper z-axis
    # (the inverse of the base offset applied in step_circular_path).
    finger_center_positions = None
    if show_finger_center and quaternions is not None:
        n_poses = min(len(positions), len(quaternions))
        z_axes = R.from_quat(quaternions[:n_poses], scalar_first=True).as_matrix()[:, :, 2]
        finger_center_positions = positions[:n_poses] + 0.5 * gripper_length * z_axes
        print(f"Finger center position range: {_format_xyz_range(finger_center_positions)}")

    fig = plt.figure(figsize=(15, 10))

    # Top-down view (X-Y), with the finger center arc overlaid.
    ax1 = fig.add_subplot(2, 2, 1)
    _plot_path_2d(ax1, positions, (0, 1), "b-o", "Gripper Base Path", "Base Start", "Base End")
    if finger_center_positions is not None:
        ax1.plot(
            finger_center_positions[:, 0],
            finger_center_positions[:, 1],
            "r--s",
            linewidth=1,
            markersize=3,
            alpha=0.7,
            label="Finger Center Arc",
        )
        ax1.scatter(
            finger_center_positions[0, 0],
            finger_center_positions[0, 1],
            color="darkgreen",
            s=80,
            marker="s",
            label="Finger Start",
            zorder=5,
        )
        ax1.scatter(
            finger_center_positions[-1, 0],
            finger_center_positions[-1, 1],
            color="darkred",
            s=80,
            marker="s",
            label="Finger End",
            zorder=5,
        )
    if joint_position is not None:
        finger_start = finger_center_positions[0] if finger_center_positions is not None else None
        _plot_joint_2d(ax1, joint_position, (0, 1), finger_start, "Joint-Finger Offset")
    _finish_2d_axes(ax1, "X (m)", "Y (m)", "Top-down View (X-Y)")

    # Side (X-Z) and front (Y-Z) views share the same layout.
    side_views = (
        (2, (0, 2), "r-o", "X (m)", "Z (m)", "Side View (X-Z)"),
        (3, (1, 2), "g-o", "Y (m)", "Z (m)", "Front View (Y-Z)"),
    )
    for slot, dims, fmt, xlabel, ylabel, view_title in side_views:
        ax = fig.add_subplot(2, 2, slot)
        _plot_path_2d(ax, positions, dims, fmt, "Path", "Start", "End")
        if joint_position is not None:
            _plot_joint_2d(ax, joint_position, dims, positions[0], "Joint-Start Offset")
        _finish_2d_axes(ax, xlabel, ylabel, view_title)

    # 3D view.
    ax4 = fig.add_subplot(2, 2, 4, projection="3d")
    ax4.plot(
        positions[:, 0],
        positions[:, 1],
        positions[:, 2],
        "b-o",
        linewidth=2,
        markersize=4,
        label="Path",
    )
    ax4.scatter(
        positions[0, 0], positions[0, 1], positions[0, 2], color="green", s=100, label="Start", zorder=5
    )
    ax4.scatter(
        positions[-1, 0], positions[-1, 1], positions[-1, 2], color="red", s=100, label="End", zorder=5
    )
    if joint_position is not None:
        ax4.scatter(
            joint_position[0],
            joint_position[1],
            joint_position[2],
            color="purple",
            s=200,
            marker="*",
            label="Joint",
            zorder=6,
        )
        ax4.plot(
            [joint_position[0], positions[0, 0]],
            [joint_position[1], positions[0, 1]],
            [joint_position[2], positions[0, 2]],
            "k--",
            alpha=0.5,
            linewidth=1,
            label="Joint-Start Offset",
        )
    ax4.set_xlabel("X (m)")
    ax4.set_ylabel("Y (m)")
    ax4.set_zlabel("Z (m)")
    ax4.set_title("3D View")

    # The 3D legend explains the orientation-arrow colors (it replaces the path labels).
    axis_colors = (("red", "X"), ("green", "Y"), ("blue", "Z"))
    ax4.legend(
        handles=[
            FancyArrowPatch((0, 0), (0.1, 0), color=color, linewidth=2, label=f"Gripper {name}-axis")
            for color, name in axis_colors
        ],
        loc="upper right",
    )

    # Orientation triads at up to 8 evenly spaced waypoints.
    if quaternions is not None:
        n_arrows = min(8, len(positions))
        arrow_indices = np.linspace(0, len(positions) - 1, n_arrows, dtype=int)
        for label_idx, idx in enumerate(arrow_indices):
            if idx >= len(quaternions):
                continue
            rot_matrix = R.from_quat(quaternions[idx], scalar_first=True).as_matrix()
            origin = positions[idx]
            for col, (color, _name) in enumerate(axis_colors):
                direction = rot_matrix[:, col]
                ax4.quiver(
                    origin[0],
                    origin[1],
                    origin[2],
                    direction[0],
                    direction[1],
                    direction[2],
                    length=0.15,
                    color=color,
                    alpha=0.8,
                    linewidth=2,
                )
            ax4.text(
                origin[0],
                origin[1],
                origin[2],
                f"P{label_idx}",
                fontsize=8,
                color="black",
                weight="bold",
            )

    # Add the suptitle first, then lay out with the top strip reserved for it,
    # so the title does not overlap the subplot titles.
    fig.suptitle(title, fontsize=16, y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.95))

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Path visualization saved to {save_path}")

    plt.show()

    # Print path statistics.
    print("\nPath Statistics:")
    print("Gripper Base Path:")
    print(f"  Total distance: {np.sum(np.linalg.norm(np.diff(positions, axis=0), axis=1)):.3f} m")
    print(f"  Number of waypoints: {len(positions)}")
    print(f"  Start position: {positions[0]}")
    print(f"  End position: {positions[-1]}")

    if finger_center_positions is not None:
        finger_distance = np.sum(np.linalg.norm(np.diff(finger_center_positions, axis=0), axis=1))
        print("Finger Center Arc:")
        print(f"  Total distance: {finger_distance:.3f} m")
        print(f"  Start position: {finger_center_positions[0]}")
        print(f"  End position: {finger_center_positions[-1]}")

    if joint_position is not None:
        print(f"Joint position: {joint_position}")
        print(
            f"Distance from joint to base start: {np.linalg.norm(positions[0] - joint_position):.3f} m"
        )
        print(
            f"Distance from joint to base end: {np.linalg.norm(positions[-1] - joint_position):.3f} m"
        )
        if finger_center_positions is not None:
            print(
                f"Distance from joint to finger start: {np.linalg.norm(finger_center_positions[0] - joint_position):.3f} m"
            )
            print(
                f"Distance from joint to finger end: {np.linalg.norm(finger_center_positions[-1] - joint_position):.3f} m"
            )

    if quaternions is not None:
        print(f"Start orientation (quaternion): {quaternions[0]}")
        print(f"End orientation (quaternion): {quaternions[-1]}")

        # Geodesic angle between the first and last orientation.
        start_rot = R.from_quat(quaternions[0], scalar_first=True)
        end_rot = R.from_quat(quaternions[-1], scalar_first=True)
        angle_diff = (start_rot.inv() * end_rot).magnitude()
        print(f"Total orientation change: {np.degrees(angle_diff):.1f}°")

    return fig

asset_names

Functions:

Name Description
get_child_body_ids
get_child_body_names
get_thor_name

get_child_body_ids

get_child_body_ids(model: MjModel, parent_id: int) -> list[int]
Source code in molmo_spaces/utils/asset_names.py
def get_child_body_ids(model: MjModel, parent_id: int) -> list[int]:
    """Return the ids of the bodies whose direct parent is ``parent_id``.

    MuJoCo records the world body (id 0) as its own parent
    (``body_parentid[0] == 0``). A plain ``parent == parent_id`` scan would
    therefore list the world body as a child of itself, so a body is never
    reported as its own child here.

    Args:
        model: Model exposing ``body_parentid`` (one parent id per body).
        parent_id: Id of the parent body.

    Returns:
        Child body ids in increasing order.
    """
    parent_ids = model.body_parentid  # parent id of every body, indexed by body id
    return [
        body_id
        for body_id, pid in enumerate(parent_ids)
        if pid == parent_id and body_id != parent_id
    ]

get_child_body_names

get_child_body_names(model: MjModel, parent_id: int) -> list[str]
Source code in molmo_spaces/utils/asset_names.py
def get_child_body_names(model: MjModel, parent_id: int) -> list[str]:
    """Return the names of the child bodies of ``parent_id``.

    Children are resolved with ``get_child_body_ids`` and returned in the same
    order; each name is read via ``model.body(child_id).name``.
    """
    child_bodies = (model.body(child_id) for child_id in get_child_body_ids(model, parent_id))
    return [body.name for body in child_bodies]

get_thor_name

get_thor_name(model, pickup_obj)
Source code in molmo_spaces/utils/asset_names.py
def get_thor_name(model, pickup_obj):
    """Derive a THOR asset name for ``pickup_obj`` from its body names.

    The source name is the object's first child body name if it has children,
    otherwise ``pickup_obj.name``. Two naming schemes are handled:

    * names containing ``"|"`` (proctor, presumably ProcTHOR): the last
      ``|``-separated segment must look like ``<digits>_<letters/_><digits/_>``,
      and the part after the leading digits is used;
    * other names (iTHOR): every occurrence of ``pickup_obj.name + "_"`` is
      removed, and the leading ``<letters/_><digits/_>`` run is used.

    Leading/trailing underscores are stripped from the result.

    Raises:
        ValueError: if the name does not match the expected pattern.
    """
    child_names = get_child_body_names(model, pickup_obj.object_id)
    object_name = child_names[0] if child_names else pickup_obj.name

    if "|" in object_name:  # name from proctor
        name_end = object_name.split("|")[-1]
        pattern = r"^\d+_([A-Za-z_]+[\d_]+)"
    else:  # name from ithor
        name_end = object_name.replace(pickup_obj.name + "_", "")
        pattern = r"^([A-Za-z_]+[\d_]+)"

    match = re.search(pattern, name_end)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None.group().
        raise ValueError(
            f"Cannot derive a THOR name from body name {object_name!r} "
            f"(parsed segment {name_end!r})"
        )
    return match.group(1).strip("_")

benchmark_utils

Benchmark utilities for kinematics outlier detection in trajectory H5 files.

NOTE ON TEMPORAL INDEXING: All sensor arrays (qpos, cmd, jpr) share the same index space, but within a step() the controller target is set before physics runs, so qpos[t] only partially converges toward cmd[t] (~30% per step). We therefore use qpos[t+1] instead of qpos[t] to measure tracking quality:

tracking_error[t]          = qpos[t+1] - cmd[t]
relative_tracking_error[t] = (qpos[t+1] - cmd[t]) / jpr[t]

NOTE ON COLLISION DETECTION: The better way to detect collisions would be via residual torques (measured minus expected from rigid-body dynamics), but no torque sensors are currently recorded. Tracking error is a reasonable proxy (collisions cause position deviations), though it is less sensitive than torques.

Functions:

Name Description
compute_bounds_std
episodes_with_kinematics_outliers

Find episodes with kinematics outliers across H5 trajectory files.

resolve_asset_id

Resolve the asset ID (UID) for a task object by name.

save_outlier_gifs

Save GIFs of merged outlier segments for visual inspection.

save_signal_histograms

Collect raw relative tracking error values and save histograms as PNGs.

Attributes:

Name Type Description
THOR_CAT_SIMPLIFY
log

THOR_CAT_SIMPLIFY module-attribute

THOR_CAT_SIMPLIFY = {'saltshaker': 's/p shaker', 'peppershaker': 's/p shaker', 'tomato': 'fruit', 'apple': 'fruit', 'butterknife': 'knife', 'boiler': 'kettle', 'winebottle': 'bottle', 'atomizer': 'spray bottle', 'remotecontrol': 'remote control', 'soapdispenser': 'soap dispenser', 'tissuepaper': 'tissue paper'}

log module-attribute

log = getLogger(__name__)

compute_bounds_std

compute_bounds_std(stats: dict, std_mult: float) -> dict[tuple[str, int], tuple[float, float, float, float]]
Source code in molmo_spaces/utils/benchmark_utils.py
def compute_bounds_std(
    stats: dict,
    std_mult: float,
    *,
    min_count: int = 10,
) -> dict[tuple[str, int], tuple[float, float, float, float]]:
    """Turn running statistics into symmetric ``mean ± std_mult * std`` bounds.

    Args:
        stats: Mapping from key to a dict with ``"count"``, ``"mean"`` and
            ``"M2"`` entries, where ``M2`` is treated as the sum of squared
            deviations from the mean (presumably Welford accumulators; confirm
            with the collector).
        std_mult: Half-width of the interval, in standard deviations.
        min_count: Keys with fewer samples than this are skipped (default 10).
            Keys with zero samples are always skipped.

    Returns:
        dict mapping each retained key to ``(lower, upper, mean, std)``, where
        ``std`` is the population standard deviation ``sqrt(M2 / count)``.
    """
    required = max(min_count, 1)  # never divide by a zero count
    bounds = {}
    for key, s in stats.items():
        count = s["count"]
        if count < required:
            continue

        mean = s["mean"]
        std = np.sqrt(s["M2"] / count)
        half_width = std_mult * std
        bounds[key] = (mean - half_width, mean + half_width, mean, std)

    return bounds

episodes_with_kinematics_outliers

episodes_with_kinematics_outliers(data_path: Path | str, max_files: int | None = None, num_workers: int = 32, std_mult: float = 8.0, action_groups: Collection[str] = ('arm',), skip_first_n_steps: int = 1, min_joint_pos_rel_magnitude: float = 0.015, std_mult_clip_sigma: float = 4.0, std_mult_negative_only: bool = True, print_stats: bool = False) -> tuple[list[dict], dict]

Find episodes with kinematics outliers across H5 trajectory files.

Uses relative tracking error (qpos[t+1] - cmd[t]) / jpr[t] as the outlier signal (see NOTE ON TEMPORAL INDEXING in module docstring).

Parameters:

Name Type Description Default
data_path Path | str

Root directory containing house_*/trajectories*.h5 files.

required
max_files int | None

Cap on the number of H5 files to process (for testing).

None
num_workers int

Parallel workers for stats collection and outlier detection.

32
std_mult float

Number of standard deviations beyond which a value is an outlier.

8.0
action_groups Collection[str]

Which move-group keys to examine (e.g. ("arm",)).

('arm',)
skip_first_n_steps int

Ignore the first N timesteps of each trajectory.

1
min_joint_pos_rel_magnitude float

Minimum absolute joint_pos_rel value (in radians) below which the sample is discarded. Small commanded deltas with possibly high undershoot are assumed to occur during the grasping approach. Choose this threshold so that the grasping mode fades out, leaving only the free-space motion regime.

0.015
std_mult_clip_sigma float

Number of sigmas for iterative sigma-clipping to robustly estimate the spread.

4.0
std_mult_negative_only bool

When True, only flag values below the lower bound (undershooting). Positive overshooting is accepted.

True
print_stats bool

Print per-dimension statistics.

False

Returns:

Type Description
tuple[list[dict], dict]

Tuple of (outlier_episodes, bounds).

Source code in molmo_spaces/utils/benchmark_utils.py
def episodes_with_kinematics_outliers(
    data_path: Path | str,
    max_files: int | None = None,
    num_workers: int = 32,
    std_mult: float = 8.0,  # to determine outlier
    action_groups: Collection[str] = ("arm",),
    skip_first_n_steps: int = 1,
    min_joint_pos_rel_magnitude: float = 1.5e-2,
    std_mult_clip_sigma: float = 4.0,
    std_mult_negative_only: bool = True,
    print_stats: bool = False,
) -> tuple[list[dict], dict]:
    """Find episodes with kinematics outliers across H5 trajectory files.

    Uses relative tracking error ``(qpos[t+1] - cmd[t]) / jpr[t]`` as the
    outlier signal (see NOTE ON TEMPORAL INDEXING in module docstring).

    Runs two parallel passes over the same sorted list of H5 files:

    1. Collect the raw signal per ``(action_group, dim)`` and estimate a
       robust mean/std per dimension with iterative sigma-clipping.
    2. Flag samples outside the resulting bounds and group them by episode.

    Args:
        data_path: Root directory containing house_*/trajectories*.h5 files.
        max_files: Cap on the number of H5 files to process (for testing).
            Files are sorted first, so the selected subset is reproducible.
        num_workers: Parallel workers for stats collection and outlier detection.
        std_mult: Number of standard deviations beyond which a value is an outlier.
        action_groups: Which move-group keys to examine (e.g. ``("arm",)``).
        skip_first_n_steps: Ignore the first N timesteps of each trajectory.
        min_joint_pos_rel_magnitude: Minimum absolute ``joint_pos_rel`` value
            (in radians) below which the sample is discarded.  Small
            commanded deltas with possibly high undershoot
            are assumed to occur during the grasping approach.
            Choose this threshold so that the grasping mode
            fades out, leaving only the free-space motion regime.
        std_mult_clip_sigma: Number of sigmas for iterative sigma-clipping
            to robustly estimate the spread.
        std_mult_negative_only: When True, only flag values below the lower
            bound (undershooting).  Positive overshooting is accepted.
        print_stats: Print per-dimension statistics.

    Returns:
        Tuple of ``(outlier_episodes, bounds)``.

        ``outlier_episodes`` is a list of dicts. Each dict has the keys
        ``house``, ``h5_file`` and ``traj_idx``. It also has the parallel
        lists ``timesteps``, ``body_parts``, ``dims``, ``values`` and
        ``std_aways``, with one entry per flagged sample. The list is
        sorted by number of flagged samples, most first.

        ``bounds`` maps ``(action_group, dim)`` to
        ``(lower, upper, mean, std)``.
    """
    data_path = Path(data_path)

    # Path.glob yields files in filesystem order; sort so that ``max_files``
    # selects a reproducible subset (the same one save_signal_histograms uses).
    h5_files = sorted(data_path.glob("house_*/trajectories*.h5"))
    if max_files:
        h5_files = h5_files[:max_files]

    print(f"Found {len(h5_files)} H5 files")
    print(f"Using {num_workers} workers")
    print(f"Outlier threshold: mean {'±' if not std_mult_negative_only else '-'} {std_mult} * std")

    # ---- Pass 1: robust per-dimension statistics ----
    print("\n[Pass 1] Computing stats...")

    worker_args = [
        (
            str(f),
            action_groups,
            skip_first_n_steps,
            min_joint_pos_rel_magnitude,
        )
        for f in h5_files
    ]

    all_reltrack_raw: dict[tuple[str, int], list[float]] = defaultdict(list)
    with Pool(num_workers) as pool:
        for reltrack_raw in tqdm(
            pool.imap_unordered(_collect_stats_worker, worker_args, chunksize=20),
            total=len(worker_args),
            desc="Stats",
        ):
            for key, vals in reltrack_raw.items():
                all_reltrack_raw[key].extend(vals)

    # Store the clipped mean/std in count/mean/M2 form (M2 = std**2 * count)
    # before handing it to compute_bounds_std.
    merged_stats: dict = {}
    if all_reltrack_raw:
        print(f"\n  Sigma-clipping relative tracking error (clip={std_mult_clip_sigma}σ) ...")
        for key, raw_vals in sorted(all_reltrack_raw.items()):
            clipped_mean, clipped_std, n_kept = _sigma_clip(
                np.array(raw_vals), clip_sigma=std_mult_clip_sigma
            )
            merged_stats[key] = {
                "count": n_kept,
                "mean": clipped_mean,
                "M2": clipped_std**2 * n_kept,
            }
            ag, dim = key
            print(
                f"    {ag}[{dim}]: {len(raw_vals)} total → {n_kept} after clip, "
                f"mean = {clipped_mean:.6f}, std = {clipped_std:.6f}"
            )

    bounds = compute_bounds_std(merged_stats, std_mult)

    if print_stats:
        header = (
            f"{'Dimension':<25} {'Mean':>12} {'Std':>12} {'Lower':>12} {'Upper':>12} {'Count':>12}"
        )
        # Dict keys are unique, so sorting items only ever compares keys.
        section_items = sorted(bounds.items())
        if section_items:
            print(f"\n  [relative_tracking_error]  ({len(section_items)} dims)")
            print(f"  {header}")
            print(f"  {'-' * len(header)}")
            for (ag, dim), (lower, upper, mean, std) in section_items:
                count = merged_stats[(ag, dim)]["count"]
                label = f"{ag}[{dim}]"
                # Same column layout as ``header``, including the separator
                # space after the 25-char label column.
                print(
                    f"  {label:<25} "
                    f"{mean:>12.6f} {std:>12.6f} {lower:>12.6f} {upper:>12.6f} {count:>12}"
                )

    # ---- Pass 2: flag samples outside the bounds ----
    print("\n[Pass 2] Finding outliers...")

    worker_args = [
        (
            str(f),
            action_groups,
            skip_first_n_steps,
            bounds,
            min_joint_pos_rel_magnitude,
            std_mult_negative_only,
        )
        for f in h5_files
    ]
    all_outliers = []
    with Pool(num_workers) as pool:
        for outliers in tqdm(
            pool.imap_unordered(_find_outliers_worker, worker_args, chunksize=20),
            total=len(worker_args),
            desc="Outliers",
        ):
            all_outliers.extend(outliers)

    print(f"\nFound {len(all_outliers)} outlier values (>{std_mult} std from mean)")

    for d in all_outliers:
        d["body_part"] = d["action_group"].removeprefix("reltrack_")

    # Group flagged samples by episode: (house, h5_file, traj_idx).
    outlier_episodes: dict = defaultdict(
        lambda: {
            "timesteps": [],
            "body_parts": [],
            "dims": [],
            "values": [],
            "std_aways": [],
        }
    )

    for d in all_outliers:
        episode = outlier_episodes[(d["house"], d["h5_file"], d["traj_idx"])]
        episode["timesteps"].append(d["timestep"])
        episode["body_parts"].append(d["body_part"])
        episode["dims"].append(d["action_dim"])
        episode["values"].append(d["value"])
        episode["std_aways"].append(d["std_away"])

    for (house, h5_file, traj_idx), episode_dict in outlier_episodes.items():
        episode_dict["house"] = house
        episode_dict["h5_file"] = h5_file
        episode_dict["traj_idx"] = traj_idx

    sorted_episodes = sorted(
        outlier_episodes.values(), key=lambda x: len(x["timesteps"]), reverse=True
    )
    return sorted_episodes, bounds

resolve_asset_id

resolve_asset_id(object_name: str, task_config, scene_dataset: str | None = None, data_split: str | None = None, house_index: int | None = None) -> str | None

Resolve the asset ID (UID) for a task object by name.

Tries two strategies in order:

  1. added_objects (frozen config / JSON benchmark): If the object was dynamically added to the scene (e.g. a place receptacle), its XML path is stored in task_config.added_objects. The UID is the stem of the XML filename (<uid>.xml).

  2. Scene metadata (via SceneMeta): For objects that are part of the base scene (e.g. pickup objects), the asset_id is looked up from the scene's *_metadata.json using scene_dataset, data_split, and house_index.

Parameters:

Name Type Description Default
object_name str

MuJoCo body name of the object (e.g. "Mug_12" or "place_receptacle/Bowl_25").

required
task_config

A task config object (e.g. PickAndPlaceTaskConfig) that has an added_objects dict mapping object names to relative XML paths.

required
scene_dataset str | None

Scene dataset name (e.g. "procthor-objaverse"). Required for the SceneMeta fallback.

None
data_split str | None

Data split (e.g. "val"). Required for the SceneMeta fallback.

None
house_index int | None

House index within the dataset/split. Required for the SceneMeta fallback.

None

Returns:

Type Description
str | None

The asset UID string, or None if it could not be resolved.

Source code in molmo_spaces/utils/benchmark_utils.py
def resolve_asset_id(
    object_name: str,
    task_config,
    scene_dataset: str | None = None,
    data_split: str | None = None,
    house_index: int | None = None,
) -> str | None:
    """Resolve the asset ID (UID) for a task object by name.

    Tries two strategies in order:

    1. **added_objects** (frozen config / JSON benchmark): If the object was
       dynamically added to the scene (e.g. a place receptacle), its XML path
       is stored in ``task_config.added_objects``. The UID is the stem of the
       XML filename (``<uid>.xml``).

    2. **Scene metadata** (via SceneMeta): For objects that are part of the
       base scene (e.g. pickup objects), the asset_id is looked up from the
       scene's ``*_metadata.json`` using ``scene_dataset``, ``data_split``,
       and ``house_index``.

    Args:
        object_name: MuJoCo body name of the object (e.g. ``"Mug_12"``
            or ``"place_receptacle/Bowl_25"``).
        task_config: A task config object (e.g. ``PickAndPlaceTaskConfig``)
            that has an ``added_objects`` dict mapping object names to
            relative XML paths.
        scene_dataset: Scene dataset name (e.g. ``"procthor-objaverse"``).
            Required for the SceneMeta fallback.
        data_split: Data split (e.g. ``"val"``). Required for the SceneMeta
            fallback.
        house_index: House index within the dataset/split. Required for the
            SceneMeta fallback.

    Returns:
        The asset UID string, or ``None`` if it could not be resolved.
    """
    # Strategy 1: look up in added_objects (dynamically added assets like receptacles)
    if isinstance(task_config, dict):
        added_objects = task_config.get("added_objects") or {}
    else:
        added_objects = getattr(task_config, "added_objects", None) or {}
    if object_name in added_objects:
        xml_rel_path = added_objects[object_name]
        uid = Path(xml_rel_path).stem
        return uid

    # Strategy 2: look up in scene metadata via SceneMeta
    if scene_dataset is not None and data_split is not None and house_index is not None:
        from molmo_spaces.molmo_spaces_constants import get_scenes
        from molmo_spaces.utils.lazy_loading_utils import install_scene_from_path
        from molmo_spaces.utils.scene_metadata_utils import SceneMeta

        scene_source = get_scenes(scene_dataset, data_split)[data_split][house_index]
        if isinstance(scene_source, dict):
            scene_source = scene_source["ceiling"]
        assert scene_source.endswith(".xml")
        install_scene_from_path(scene_source)

        scene_meta = SceneMeta.get_scene_metadata(scene_source)
        if scene_meta is not None:
            asset_id = scene_meta.get("objects", {}).get(object_name, {}).get("asset_id")
            if asset_id is not None:
                return asset_id

    log.warning(f"Could not resolve asset_id for '{object_name}'")
    return None

save_outlier_gifs

save_outlier_gifs(outlier_episodes: list[dict], output_dir: Path | str, merge_gap: int = 20, context_frames: int = 5, camera_preference: tuple[str, ...] = ('exo_camera_1', 'wrist_camera'), fps_gif: float = 5.0, sample_rate: float = 0.1, max_samples: int = 50) -> int

Save GIFs of merged outlier segments for visual inspection.

Outlier timesteps within merge_gap frames of each other are merged into a single segment (iteratively, until no more merges are possible). Each segment is then padded by context_frames on both sides and saved as one GIF.

Filename convention::

{house}_{batch}_traj{idx}_f{start}-{end}_{max_std:.1f}std.gif

Parameters:

Name Type Description Default
outlier_episodes list[dict]

Output of :func:episodes_with_kinematics_outliers.

required
output_dir Path | str

Directory to write GIF files into (created if needed).

required
merge_gap int

Maximum frame distance between two outlier timesteps for them to be merged into the same segment.

20
context_frames int

Number of extra frames to include before the first and after the last outlier in each segment.

5
camera_preference tuple[str, ...]

Ordered list of camera names to try when looking for a video reference inside the H5 file.

('exo_camera_1', 'wrist_camera')
fps_gif float

Playback speed of the output GIF (frames per second).

5.0
sample_rate float

Fraction of outlier episodes to render; when below 1.0, episodes are sampled at random.

0.1
max_samples int

Absolute maximum number of episodes to render (only applied if sample_rate < 1.0).

50

Returns:

Type Description
int

Number of GIF files written.

Source code in molmo_spaces/utils/benchmark_utils.py
def save_outlier_gifs(
    outlier_episodes: list[dict],
    output_dir: Path | str,
    merge_gap: int = 20,
    context_frames: int = 5,
    camera_preference: tuple[str, ...] = ("exo_camera_1", "wrist_camera"),
    fps_gif: float = 5.0,
    sample_rate: float = 0.1,
    max_samples: int = 50,
) -> int:
    """Save GIFs of merged outlier segments for visual inspection.

    Outlier timesteps within *merge_gap* frames of each other are merged
    into a single segment (one pass over the sorted timesteps, which is
    equivalent to merging until no more merges are possible).  Each segment
    is then padded by *context_frames* on both sides and saved as one GIF.
    Timesteps are used directly as video frame indices. This assumes one
    recorded frame per timestep.

    Filename convention::

        {house}_{batch}_traj{idx}_f{start}-{end}_{max_std:.1f}std.gif

    Args:
        outlier_episodes: Output of :func:`episodes_with_kinematics_outliers`.
        output_dir: Directory to write GIF files into (created if needed).
        merge_gap: Maximum frame distance between two outlier timesteps for
            them to be merged into the same segment.
        context_frames: Number of extra frames to include before the first
            and after the last outlier in each segment.
        camera_preference: Ordered list of camera names to try when looking
            for a video reference inside the H5 file.
        fps_gif: Playback speed of the output GIF (frames per second).
        sample_rate: Fraction of *outlier_episodes* to render; when below
            1.0 the episodes are sampled at random.
        max_samples: Absolute maximum number of episodes to render
            (only applied if sample_rate < 1.0).

    Returns:
        Number of GIF files written.
    """
    import random

    import imageio

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    n_saved = 0

    if sample_rate < 1.0 and outlier_episodes:
        rlist = outlier_episodes[:]
        random.shuffle(rlist)
        # Keep at least one episode and at most max_samples.
        outlier_episodes = rlist[: min(max_samples, max(1, int(len(rlist) * sample_rate)))]

    # NOTE(review): assumes an imageio version whose GIF writer interprets
    # ``duration`` as milliseconds per frame -- confirm against the pinned version.
    duration_ms = int(1000.0 / fps_gif)

    for episode in tqdm(outlier_episodes, desc="Saving outlier GIFs"):
        h5_path = Path(episode["h5_file"])
        traj_idx = episode["traj_idx"]
        house = episode["house"]

        # Derive batch label from H5 filename
        # e.g. "trajectories_batch_0_of_1.h5" -> "batch_0_of_1"
        batch_label = h5_path.stem.replace("trajectories", "").lstrip("_") or "0"

        video_path = _locate_outlier_video(h5_path, traj_idx, camera_preference)
        if video_path is None:
            print(
                f"  Warning: no video found for traj_{traj_idx} of {h5_path}"
                f" (tried H5 metadata, constructed paths, and glob)"
            )
            continue

        # --- read video frames once per episode ---
        try:
            reader = imageio.get_reader(str(video_path), "ffmpeg")
            try:
                frames = list(reader)
            finally:
                # Release the reader even when decoding fails part-way.
                reader.close()
        except Exception as e:
            print(f"  Warning: failed to read video {video_path}: {e}")
            continue

        n_frames = len(frames)
        if n_frames == 0:
            print(f"  Warning: video has 0 frames: {video_path}")
            continue

        per_step, segments = _merge_outlier_timesteps(
            episode["timesteps"], episode["std_aways"], merge_gap
        )

        for seg_steps in segments:
            max_std = max(per_step[t] for t in seg_steps)

            # Frame range with context padding, clamped to the video length.
            seg_start = max(0, seg_steps[0] - context_frames)
            seg_end = min(n_frames - 1, seg_steps[-1] + context_frames)
            clip = frames[seg_start : seg_end + 1]
            if not clip:
                # Segment lies entirely beyond the end of the video.
                continue

            gif_name = (
                f"{house}_{batch_label}_traj{traj_idx}_f{seg_start}-{seg_end}_{max_std:.1f}std.gif"
            )
            gif_path = output_dir / gif_name
            try:
                imageio.mimsave(str(gif_path), clip, duration=duration_ms, loop=0)
                n_saved += 1
            except Exception as e:
                print(f"  Warning: failed to save {gif_path}: {e}")

    print(f"Saved {n_saved} outlier GIFs to {output_dir}")
    return n_saved


def _locate_outlier_video(
    h5_path: Path, traj_idx: int, camera_preference: tuple[str, ...]
) -> Path | None:
    """Return an existing MP4 for ``traj_{traj_idx}`` of *h5_path*, or ``None``.

    Lookup order:

    1. The filename stored in ``traj_{idx}/obs/sensor_data/<camera>`` of the
       H5 file. The first camera in *camera_preference* that is present is
       used; otherwise the first recorded camera.
    2. ``episode_{idx:08d}_{camera}{suffix}.mp4`` for each preferred camera.
    3. Any ``episode_{idx:08d}_*{suffix}.mp4`` (first in sorted order).

    Here ``suffix`` is the H5 stem with ``"trajectories"`` removed.
    """
    house_dir = h5_path.parent
    # e.g. "trajectories_batch_1_of_2" -> "_batch_1_of_2"
    suffix = h5_path.stem.replace("trajectories", "")

    # Try 1: video filename stored in the H5 sensor_data metadata.
    video_path = None
    try:
        with h5py.File(h5_path, "r") as f:
            traj_grp = f.get(f"traj_{traj_idx}")
            obs_grp = traj_grp.get("obs") if traj_grp is not None else None
            sd_grp = obs_grp.get("sensor_data") if obs_grp is not None else None
            if sd_grp is not None:
                camera_name = next((cam for cam in camera_preference if cam in sd_grp), None)
                if camera_name is None:
                    # No preferred camera recorded: fall back to the first one.
                    camera_name = next(iter(sd_grp.keys()), None)
                if camera_name is not None:
                    byte_arr = sd_grp[camera_name][:]
                    video_filename = byte_arr.tobytes().decode("utf-8").rstrip("\x00")
                    video_path = house_dir / video_filename
    except Exception:
        # Deliberately best-effort: unreadable or malformed metadata falls
        # through to the filename-based lookups below.
        pass

    if video_path is not None and video_path.exists():
        return video_path

    # Try 2: conventional name episode_{traj_idx:08d}_{camera}{suffix}.mp4
    for cam in camera_preference:
        candidate = house_dir / f"episode_{traj_idx:08d}_{cam}{suffix}.mp4"
        if candidate.exists():
            return candidate

    # Try 3: any camera recorded for this episode index.
    matches = sorted(house_dir.glob(f"episode_{traj_idx:08d}_*{suffix}.mp4"))
    return matches[0] if matches else None


def _merge_outlier_timesteps(
    timesteps: list[int], std_aways: list[float], merge_gap: int
) -> tuple[dict[int, float], list[list[int]]]:
    """Collapse outlier timesteps into segments of nearby steps.

    Returns ``(per_step, segments)``:

    - ``per_step`` maps each timestep to its worst (largest) ``std_away``.
    - ``segments`` is a list of ascending timestep runs in which
      consecutive entries are at most *merge_gap* apart.
    """
    per_step: dict[int, float] = {}
    for t, sa in zip(timesteps, std_aways):
        if t not in per_step or sa > per_step[t]:
            per_step[t] = sa

    segments: list[list[int]] = []
    for t in sorted(per_step):
        if segments and t - segments[-1][-1] <= merge_gap:
            segments[-1].append(t)
        else:
            segments.append([t])
    return per_step, segments

save_signal_histograms

save_signal_histograms(data_path: Path | str, output_dir: Path | str, bounds: dict[tuple[str, int], tuple[float, float, float, float]] | None = None, action_groups: Collection[str] = ('arm',), skip_first_n_steps: int = 1, max_files: int | None = None, num_workers: int = 32, min_joint_pos_rel_magnitude: float = 0.015) -> None

Collect raw relative tracking error values and save histograms as PDFs.

Generates one figure per eps threshold level, each with one subplot per joint dimension. If bounds is provided the outlier thresholds are drawn as vertical lines.

Also generates two joint-histogram (contour) plots of joint_pos_rel vs relative tracking error: one with a tight x-range and one with a wide range showing the chosen min_joint_pos_rel_magnitude threshold.

Source code in molmo_spaces/utils/benchmark_utils.py
def save_signal_histograms(
    data_path: Path | str,
    output_dir: Path | str,
    bounds: dict[tuple[str, int], tuple[float, float, float, float]] | None = None,
    action_groups: Collection[str] = ("arm",),
    skip_first_n_steps: int = 1,
    max_files: int | None = None,
    num_workers: int = 32,
    min_joint_pos_rel_magnitude: float = 1.5e-2,  # TODO ensure it matches the one in episodes_with_kinematics_outliers
) -> None:
    """Collect raw relative tracking error values and save histograms as PNGs.

    Generates one figure per eps threshold level, each with one subplot per
    joint dimension.  If *bounds* is provided the outlier thresholds are
    drawn as vertical lines.

    Also generates two joint-histogram (contour) plots of ``joint_pos_rel``
    vs relative tracking error: one with a tight x-range and one with a
    wide range showing the chosen *min_joint_pos_rel_magnitude* threshold.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    data_path = Path(data_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    h5_files = sorted(data_path.glob("house_*/trajectories*.h5"))
    if max_files:
        h5_files = h5_files[:max_files]

    print(f"\n[Histogram pass] Collecting raw values from {len(h5_files)} H5 files ...")
    worker_args = [
        (
            str(f),
            action_groups,
            skip_first_n_steps,
        )
        for f in h5_files
    ]

    merged_values: dict[tuple[str, int], list[float]] = defaultdict(list)
    with Pool(num_workers) as pool:
        for wv in tqdm(
            pool.imap_unordered(_collect_values_worker, worker_args, chunksize=20),
            total=len(worker_args),
            desc="Values",
        ):
            for key, vals in wv.items():
                merged_values[key].extend(vals)

    # Group keys by dimension (exclude helper keys)
    dim_map: dict[int, tuple[str, int]] = {}
    for ag, dim in sorted(merged_values.keys()):
        if ag.startswith(("reltrack_denom_", "reltrack_jpr_")):
            continue
        dim_map[dim] = (ag, dim)

    n_dims = len(dim_map)
    if n_dims == 0:
        print("No relative tracking error data found for histograms.")
        return

    reltrack_eps_levels = [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 1.5e-2, 3e-2, 6e-2, 8e-2, 1e-1]

    for hit, eps in enumerate(reltrack_eps_levels):
        ncols = min(n_dims, 4)
        nrows = (n_dims + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
        fig.suptitle(
            f"Relative tracking error  (min |cmd_delta| ≥ {eps:.1e})",
            fontsize=14,
            fontweight="bold",
        )

        for idx, (_dim, (ag, dim_idx)) in enumerate(sorted(dim_map.items())):
            ax = axes[idx // ncols][idx % ncols]
            ratios = np.array(merged_values.get((ag, dim_idx), []))
            denoms = np.array(
                merged_values.get((f"reltrack_denom_{ag.removeprefix('reltrack_')}", dim_idx), [])
            )

            if len(ratios) == 0 or len(denoms) == 0:
                ax.set_title(f"{ag}[{dim_idx}] (no data)")
                continue

            mask_eps = denoms >= eps
            filtered = ratios[mask_eps]
            n_total = len(filtered)

            if n_total == 0:
                ax.set_title(f"{ag}[{dim_idx}] (0 samples at eps={eps:.1e})")
                continue

            key = (ag, dim_idx)
            if bounds and key in bounds:
                _, _, b_mean, b_std = bounds[key]
                plot_lo = b_mean - 10.0 * b_std
                plot_hi = b_mean + 10.0 * b_std
            else:
                plot_lo, plot_hi = -1.5, 1.5

            mask_range = (filtered >= plot_lo) & (filtered <= plot_hi)
            n_clipped = int(np.sum(~mask_range))
            plot_vals = filtered[mask_range]

            n_bins = max(10, int(np.sqrt(len(plot_vals)))) if len(plot_vals) > 0 else 10
            ax.hist(
                plot_vals,
                bins=n_bins,
                alpha=0.7,
                edgecolor="black",
                linewidth=0.3,
            )
            ax.set_xlabel("value")
            ax.set_ylabel("count")
            ax.set_xlim(plot_lo, plot_hi)

            if bounds and key in bounds:
                lower, upper, mean, std = bounds[key]
                ax.axvline(mean, color="green", linestyle="-", linewidth=1.2, label="mean")
                ax.axvline(lower, color="red", linestyle="--", linewidth=1.2, label="lower")
                ax.axvline(upper, color="red", linestyle="--", linewidth=1.2, label="upper")
                ax.legend(fontsize=7)
                title = f"{ag}[{dim_idx}]  (N={n_total:,}) {std=:.3f}"
            else:
                title = f"{ag}[{dim_idx}]  (N={n_total:,})"
            if n_clipped:
                title += f"  [{n_clipped} outside plot range]"
            ax.set_title(title, fontsize=9)

        for idx in range(n_dims, nrows * ncols):
            axes[idx // ncols][idx % ncols].set_visible(False)

        fig.tight_layout(rect=[0, 0, 1, 0.93])
        eps_label = f"{eps:.1e}".replace("+", "").replace("-", "m")
        pdf_path = output_dir / f"{hit}_histogram_relative_tracking_error_eps{eps_label}.pdf"
        fig.savefig(pdf_path)
        plt.close(fig)
        print(f"Saved histogram: {pdf_path}")

    # --- Joint histograms (contour) of jpr vs relative tracking error ---
    jpr_floor = 5e-6

    # Precompute per-dimension data (shared by both plots)
    contour_data: dict[int, tuple[str, int, np.ndarray, np.ndarray]] = {}
    for dim, (ag, dim_idx) in sorted(dim_map.items()):
        ratios = np.array(merged_values.get((ag, dim_idx), []))
        jpr_signed = np.array(
            merged_values.get((f"reltrack_jpr_{ag.removeprefix('reltrack_')}", dim_idx), [])
        )
        if len(ratios) == 0 or len(jpr_signed) == 0:
            continue
        mask = np.abs(jpr_signed) >= jpr_floor
        x, y = jpr_signed[mask], ratios[mask]
        if len(x) > 0:
            contour_data[dim] = (ag, dim_idx, x, y)

    contour_configs = [
        {
            "suffix": "tight",
            "title": "Joint histogram: jpr vs reltrack (tight)",
            "x_range": lambda x: (-0.01, 0.01),
            "show_threshold": False,
        },
        {
            "suffix": "wide",
            "title": "Joint histogram: jpr vs reltrack (wide)",
            "x_range": lambda x: (np.percentile(x, 0.25), np.percentile(x, 99.75)),
            "show_threshold": True,
        },
    ]

    for cfg in contour_configs:
        ncols = min(n_dims, 4)
        nrows = (n_dims + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
        fig.suptitle(cfg["title"], fontsize=14, fontweight="bold")

        for idx, (dim, (ag, dim_idx)) in enumerate(sorted(dim_map.items())):
            ax = axes[idx // ncols][idx % ncols]
            if dim not in contour_data:
                ax.set_title(f"{ag}[{dim_idx}] (no data)")
                continue

            _, _, x, y = contour_data[dim]
            x_lo, x_hi = cfg["x_range"](x)

            n_bins_2d = max(10, int(np.sqrt(np.sqrt(len(x)))))
            x_edges = np.linspace(x_lo, x_hi, n_bins_2d + 1)
            y_edges = np.linspace(-1.5, 1.5, n_bins_2d + 1)

            hist2d, xe, ye = np.histogram2d(x, y, bins=[x_edges, y_edges])
            xc = 0.5 * (xe[:-1] + xe[1:])
            yc = 0.5 * (ye[:-1] + ye[1:])
            max_val = hist2d.max()
            if max_val > 0:
                levels = np.linspace(max_val * 0.05, max_val, 12)
                cs = ax.contour(xc, yc, hist2d.T, levels=levels, cmap="viridis", linewidths=0.8)
                ax.clabel(cs, inline=True, fontsize=6, fmt="%1.0f")

            if cfg["show_threshold"]:
                thr = min_joint_pos_rel_magnitude
                ax.axvline(-thr, color="red", linestyle="--", linewidth=1.2, label=f{thr:.1e}")
                ax.axvline(thr, color="red", linestyle="--", linewidth=1.2)
                key = (ag, dim_idx)
                if bounds and key in bounds:
                    lower, upper, _, _ = bounds[key]
                    ax.axhline(
                        lower,
                        color="purple",
                        linestyle="--",
                        linewidth=1.2,
                        label=f"lower ({lower:.2f})",
                    )
                    ax.axhline(
                        upper,
                        color="purple",
                        linestyle="--",
                        linewidth=1.2,
                        label=f"upper ({upper:.2f})",
                    )
                ax.legend(fontsize=7, loc="upper right")

            ax.set_xlabel("joint_pos_rel")
            ax.set_ylabel("relative tracking error")
            ax.set_title(f"{ag}[{dim_idx}]  (N={len(x):,})", fontsize=9)

        for idx in range(n_dims, nrows * ncols):
            axes[idx // ncols][idx % ncols].set_visible(False)

        fig.tight_layout(rect=[0, 0, 1, 0.93])
        pdf_path = output_dir / f"joint_histogram_jpr_vs_reltrack_{cfg['suffix']}.pdf"
        fig.savefig(pdf_path)
        plt.close(fig)
        print(f"Saved joint histogram: {pdf_path}")

camera_utils

Functions:

Name Description
erode_segmentation_mask

Apply binary erosion to a segmentation mask.

normalize_points

Normalize image points to 0-1 range, optionally applying distortion correction.

erode_segmentation_mask

erode_segmentation_mask(mask: ndarray, iterations: int = 2) -> ndarray

Apply binary erosion to a segmentation mask.

Parameters:

Name Type Description Default
mask ndarray

Binary segmentation mask

required
iterations int

Number of erosion iterations

2

Returns:

Type Description
ndarray

Eroded binary mask

Source code in molmo_spaces/utils/camera_utils.py
def erode_segmentation_mask(mask: np.ndarray, iterations: int = 2) -> np.ndarray:
    """Shrink a binary segmentation mask by repeated binary erosion.

    Thin wrapper around ``binary_erosion``. Only the iteration count is set;
    the structuring element and border handling are left at their defaults.

    Args:
        mask: Binary segmentation mask.
        iterations: How many times the erosion is applied.

    Returns:
        The eroded binary mask.
    """
    eroded = binary_erosion(mask, iterations=iterations)
    return eroded

normalize_points

normalize_points(points: ndarray, img_width: int, img_height: int, distortion_map: ndarray | None = None) -> ndarray

Normalize image points to 0-1 range, optionally applying distortion correction.

Parameters:

Name Type Description Default
points ndarray

Array of shape (N, 2) containing (x, y) pixel coordinates

required
img_width int

Image width in pixels

required
img_height int

Image height in pixels

required
distortion_map ndarray | None

Optional distortion map for warped cameras (e.g., GoPro). Not yet implemented: passing a value raises NotImplementedError.

None

Returns:

Type Description
ndarray

Normalized points in 0-1 range as array of shape (N, 2)

Source code in molmo_spaces/utils/camera_utils.py
def normalize_points(
    points: np.ndarray,
    img_width: int,
    img_height: int,
    distortion_map: np.ndarray | None = None,
) -> np.ndarray:
    """Normalize image points to 0-1 range, optionally applying distortion correction.

    Args:
        points: Array of shape (N, 2) containing (x, y) pixel coordinates
        img_width: Image width in pixels
        img_height: Image height in pixels
        distortion_map: Optional distortion map for warped cameras (e.g., GoPro)
                       Currently not implemented - will be added in future

    Returns:
        Normalized points in 0-1 range as array of shape (N, 2)
    """
    # Apply distortion correction if provided
    if distortion_map is not None:
        raise NotImplementedError("Distortion map correction not yet implemented")

    # Normalize to 0-1 range
    normalized_points = points.copy().astype(np.float32)
    normalized_points[:, 0] /= img_width  # x coordinate
    normalized_points[:, 1] /= img_height  # y coordinate

    return normalized_points

constants

Modules:

Name Description
camera_constants

Camera hardware constants for fisheye warping and image processing.

object_constants
simulation_constants

camera_constants

Camera hardware constants for fisheye warping and image processing.

Attributes:

Name Type Description
DEFAULT_CROP_PERCENT
DEFAULT_DISTORTION_PARAMETERS
GOPRO_CAMERA_HEIGHT
GOPRO_CAMERA_WIDTH
GOPRO_VERTICAL_FOV
MODEL_43_HEIGHT
MODEL_43_WIDTH
NULL_DISTORTION_PARAMETERS
DEFAULT_CROP_PERCENT module-attribute
DEFAULT_CROP_PERCENT = 0.3
DEFAULT_DISTORTION_PARAMETERS module-attribute
DEFAULT_DISTORTION_PARAMETERS = {'k1': 0.051, 'k2': 0.144, 'k3': 0.015, 'k4': -0.018}
GOPRO_CAMERA_HEIGHT module-attribute
GOPRO_CAMERA_HEIGHT = 480
GOPRO_CAMERA_WIDTH module-attribute
GOPRO_CAMERA_WIDTH = 640
GOPRO_VERTICAL_FOV module-attribute
GOPRO_VERTICAL_FOV = 139.0
MODEL_43_HEIGHT module-attribute
MODEL_43_HEIGHT = 240
MODEL_43_WIDTH module-attribute
MODEL_43_WIDTH = 320
NULL_DISTORTION_PARAMETERS module-attribute
NULL_DISTORTION_PARAMETERS = {'k1': 0.0, 'k2': 0.0, 'k3': 0.0, 'k4': 0.0}

object_constants

Functions:

Name Description
bad_asset_ids

Attributes:

Name Type Description
AI2THOR_OBJECT_TYPE_TO_MOST_SPECIFIC_WORDNET_LEMMA
AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET
ALL_ARTICULATION_TYPES_THOR
ALL_PICKUP_SYNSETS
ALL_PICKUP_TYPES_THOR
BOOLSET_OBJECT_TYPES
EXTENDED_ARTICULATION_TYPES_THOR
ITHOR_ARTICULATED_OBJECTS
OBJNAV_SYNSETS
OBJNAV_TYPES_THOR
PICKUP_SYNSETS
PICKUP_TYPES_THOR
PICK_AND_PLACE_OBJECTS
RECEPTACLE_SYNSETS
RECEPTACLE_TYPES_THOR
RELATIVE_ANCHOR_TYPES
TARGET_EXCLUDED_SYNSETS
THOR_OBJNAV_OBJECTS_LOWERCASE
THOR_PICKUP_OBJECTS_LOWERCASE
AI2THOR_OBJECT_TYPE_TO_MOST_SPECIFIC_WORDNET_LEMMA module-attribute
AI2THOR_OBJECT_TYPE_TO_MOST_SPECIFIC_WORDNET_LEMMA = {'AlarmClock': 'alarm_clock', 'AluminumFoil': 'aluminum_foil', 'Apple': 'apple', 'AppleSliced': 'apple', 'ArmChair': 'armchair', 'BaseballBat': 'baseball_bat', 'BasketBall': 'basketball', 'Bathtub': 'bathtub', 'BathtubBasin': 'bathtub', 'Bed': 'bed', 'Blinds': 'blind', 'Book': 'book', 'Boots': 'boot', 'Bottle': 'bottle', 'Bowl': 'bowl', 'Box': 'box', 'Bread': 'bread', 'BreadSliced': 'bread', 'ButterKnife': 'butter_knife', 'CD': 'compact_disk', 'Cabinet': 'cabinet', 'Candle': 'candle', 'Cart': 'handcart', 'CellPhone': 'cellular_telephone', 'Chair': 'chair', 'Cloth': 'fabric', 'ClothesDryer': 'clothes_dryer', 'CoffeeMachine': 'coffee_maker', 'CoffeeTable': 'coffee_table', 'CounterTop': 'countertop', 'CreditCard': 'credit_card', 'Cup': 'cup', 'Curtains': 'curtain', 'Desk': 'desk', 'DeskLamp': 'table_lamp', 'Desktop': 'desktop_computer', 'DiningTable': 'dining_table', 'DishSponge': 'sponge', 'DogBed': 'pad', 'Doorframe': 'doorframe', 'Doorway': 'doorway', 'Drawer': 'drawer', 'Dresser': 'chest_of_drawers', 'Dumbbell': 'dumbbell', 'Egg': 'egg', 'EggCracked': 'egg', 'Faucet': 'faucet', 'Floor': 'flooring', 'FloorLamp': 'floor_lamp', 'Footstool': 'footstool', 'Fork': 'fork', 'Fridge': 'refrigerator', 'GarbageBag': 'bin_liner', 'GarbageCan': 'ashcan', 'HandTowel': 'hand_towel', 'HandTowelHolder': 'towel_rack', 'HousePlant': 'houseplant', 'Kettle': 'boiler', 'KeyChain': 'key_ring', 'Knife': 'knife', 'Ladle': 'ladle', 'Laptop': 'laptop', 'LaundryHamper': 'clothes_hamper', 'Lettuce': 'lettuce', 'LettuceSliced': 'lettuce', 'LightSwitch': 'electric_switch', 'Microwave': 'microwave_oven', 'Mirror': 'mirror', 'Mug': 'mug', 'Newspaper': 'newspaper', 'Ottoman': 'pouffe', 'Painting': 'painting', 'Pan': 'cooking_pan', 'PaperTowelRoll': 'paper_towel', 'Pen': 'pen', 'Pencil': 'pencil', 'PepperShaker': 'pepper_shaker', 'Pillow': 'pillow', 'Plate': 'plate', 'Plunger': "plumber's_helper", 'Poster': 'placard', 'Pot': 'pot', 'Potato': 
'Irish_potato', 'PotatoSliced': 'Irish_potato', 'RemoteControl': 'remote_control', 'RoomDecor': 'decoration', 'Safe': 'safe', 'SaltShaker': 'saltshaker', 'ScrubBrush': 'scrub_brush', 'Shelf': 'shelf', 'ShelvingUnit': 'shelf', 'ShowerCurtain': 'shower_curtain', 'ShowerDoor': 'door', 'ShowerGlass': 'door', 'ShowerHead': 'showerhead', 'SideTable': 'stand', 'Sink': 'sink', 'SinkBasin': 'sink', 'SoapBar': 'bar_soap', 'SoapBottle': 'soap_dispenser', 'Sofa': 'sofa', 'Spatula': 'spatula', 'Spoon': 'spoon', 'SprayBottle': 'atomizer', 'Statue': 'statue', 'Stool': 'stool', 'StoveBurner': 'burner', 'StoveKnob': 'knob', 'TVStand': 'stand', 'TableTopDecor': 'knickknack', 'TeddyBear': 'teddy_bear', 'Television': 'television', 'TennisRacket': 'tennis_racket', 'TissueBox': 'tissue_paper', 'Toaster': 'toaster', 'Toilet': 'crapper', 'ToiletPaper': 'toilet_tissue', 'ToiletPaperHanger': 'hanger', 'Tomato': 'tomato', 'TomatoSliced': 'tomato', 'Towel': 'towel', 'TowelHolder': 'towel_rack', 'VacuumCleaner': 'vacuum_cleaner', 'Vase': 'vase', 'Wall': 'wall', 'WashingMachine': 'automatic_washer', 'Watch': 'watch', 'WateringCan': 'watering_can', 'Window': 'window', 'WineBottle': 'wine_bottle'}
AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET module-attribute
AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET = {'AlarmClock': 'alarm_clock.n.01', 'AluminumFoil': 'aluminum_foil.n.01', 'Apple': 'apple.n.01', 'AppleSliced': 'apple.n.01', 'ArmChair': 'armchair.n.01', 'BaseballBat': 'baseball_bat.n.01', 'BasketBall': 'basketball.n.02', 'Bathtub': 'bathtub.n.01', 'BathtubBasin': 'bathtub.n.01', 'Bed': 'bed.n.01', 'Blinds': 'blind.n.03', 'Book': 'book.n.02', 'Boots': 'boot.n.01', 'Bottle': 'bottle.n.01', 'Bowl': 'bowl.n.03', 'Box': 'carton.n.02', 'Bread': 'bread.n.01', 'BreadSliced': 'bread.n.01', 'ButterKnife': 'butter_knife.n.01', 'CD': 'compact_disk.n.01', 'Cabinet': 'cabinet.n.01', 'Candle': 'candle.n.01', 'Cart': 'handcart.n.01', 'CellPhone': 'cellular_telephone.n.01', 'Chair': 'straight_chair.n.01', 'Cloth': 'fabric.n.01', 'ClothesDryer': 'clothes_dryer.n.01', 'CoffeeMachine': 'coffee_maker.n.01', 'CoffeeTable': 'coffee_table.n.01', 'CounterTop': 'countertop.n.01', 'CreditCard': 'credit_card.n.01', 'Cup': 'cup.n.01', 'Curtains': 'curtain.n.01', 'Desk': 'desk.n.01', 'DeskLamp': 'table_lamp.n.01', 'Desktop': 'desktop_computer.n.01', 'DiningTable': 'dining_table.n.01', 'DishSponge': 'sponge.n.01', 'DogBed': 'pad.n.04', 'Doorframe': 'doorframe.n.01', 'Doorway': 'doorway.n.01', 'Drawer': 'drawer.n.01', 'Dresser': 'chest_of_drawers.n.01', 'Dumbbell': 'dumbbell.n.01', 'Egg': 'egg.n.02', 'EggCracked': 'egg.n.02', 'Faucet': 'faucet.n.01', 'Floor': 'floor.n.01', 'FloorLamp': 'floor_lamp.n.01', 'Footstool': 'footstool.n.01', 'Fork': 'fork.n.01', 'Fridge': 'refrigerator.n.01', 'GarbageBag': 'bin_liner.n.01', 'GarbageCan': 'ashcan.n.01', 'HandTowel': 'hand_towel.n.01', 'HandTowelHolder': 'towel_rack.n.01', 'HousePlant': 'houseplant.n.01', 'Kettle': 'kettle.n.01', 'KeyChain': 'key_ring.n.01', 'Knife': 'knife.n.01', 'Ladle': 'ladle.n.01', 'Laptop': 'laptop.n.01', 'LaundryHamper': 'clothes_hamper.n.01', 'Lettuce': 'lettuce.n.03', 'LettuceSliced': 'lettuce.n.03', 'LightSwitch': 'switch.n.01', 'Microwave': 'microwave.n.02', 'Mirror': 'mirror.n.01', 
'Mug': 'mug.n.04', 'Newspaper': 'newspaper.n.03', 'Ottoman': 'footstool.n.01', 'Painting': 'painting.n.01', 'Pan': 'pan.n.01', 'PaperTowelRoll': 'paper_towel.n.01', 'Pen': 'pen.n.01', 'Pencil': 'pencil.n.01', 'PepperShaker': 'pepper_shaker.n.01', 'Pillow': 'pillow.n.01', 'Plate': 'plate.n.04', 'Plunger': 'plunger.n.03', 'Poster': 'poster.n.01', 'Pot': 'pot.n.01', 'Potato': 'potato.n.01', 'PotatoSliced': 'potato.n.01', 'RemoteControl': 'remote_control.n.01', 'RoomDecor': 'decoration.n.01', 'Safe': 'safe.n.01', 'SaltShaker': 'saltshaker.n.01', 'ScrubBrush': 'scrub_brush.n.01', 'Shelf': 'shelf.n.01', 'ShelvingUnit': 'shelf.n.01', 'ShowerCurtain': 'shower_curtain.n.01', 'ShowerDoor': 'door.n.01', 'ShowerGlass': 'door.n.01', 'ShowerHead': 'showerhead.n.01', 'SideTable': 'stand.n.04', 'Sink': 'sink.n.01', 'SinkBasin': 'sink.n.01', 'SoapBar': 'bar_soap.n.01', 'SoapBottle': 'soap_dispenser.n.01', 'Sofa': 'sofa.n.01', 'Spatula': 'spatula.n.01', 'Spoon': 'spoon.n.01', 'SprayBottle': 'atomizer.n.01', 'Statue': 'statue.n.01', 'Stool': 'stool.n.01', 'StoveBurner': 'burner.n.02', 'StoveKnob': 'knob.n.02', 'TVStand': 'stand.n.04', 'TableTopDecor': 'knickknack.n.01', 'TeddyBear': 'teddy.n.01', 'Television': 'television_receiver.n.01', 'TennisRacket': 'tennis_racket.n.01', 'TissueBox': 'tissue.n.02', 'Toaster': 'toaster.n.02', 'Toilet': 'toilet.n.02', 'ToiletPaper': 'toilet_tissue.n.01', 'ToiletPaperHanger': 'hanger.n.02', 'Tomato': 'tomato.n.01', 'TomatoSliced': 'tomato.n.01', 'Towel': 'towel.n.01', 'TowelHolder': 'towel_rack.n.01', 'VacuumCleaner': 'vacuum.n.04', 'Vase': 'vase.n.01', 'Wall': 'wall.n.01', 'WashingMachine': 'washer.n.03', 'Watch': 'watch.n.01', 'WateringCan': 'watering_can.n.01', 'Window': 'window.n.01', 'WineBottle': 'wine_bottle.n.01'}
ALL_ARTICULATION_TYPES_THOR module-attribute
ALL_ARTICULATION_TYPES_THOR = ['Toilet', 'Dresser', 'Safe', 'Shelving_Unit', 'ShelvingUnit', 'Side_Table', 'SideTable', 'Fridge', 'Microwave', 'Coffee_Table', 'CoffeeTable', 'Desk', 'Laptop', 'Doorways', 'Laundry_Hamper', 'LaundryHamper']
ALL_PICKUP_SYNSETS module-attribute
ALL_PICKUP_SYNSETS = [(AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET[ot]) for ot in ALL_PICKUP_TYPES_THOR]
ALL_PICKUP_TYPES_THOR module-attribute
ALL_PICKUP_TYPES_THOR = ['AlarmClock', 'AluminumFoil', 'Apple', 'AppleSliced', 'Book', 'Boots', 'Bottle', 'Bowl', 'Box', 'Bread', 'BreadSliced', 'ButterKnife', 'Candle', 'CD', 'CellPhone', 'Cloth', 'CreditCard', 'Cup', 'DishSponge', 'Dumbbell', 'Egg', 'EggCracked', 'Fork', 'HandTowel', 'Kettle', 'KeyChain', 'Knife', 'Ladle', 'Laptop', 'Lettuce', 'LettuceSliced', 'Mug', 'Newspaper', 'Pan', 'PaperTowelRoll', 'Pen', 'Pencil', 'PepperShaker', 'Pillow', 'Plate', 'Plunger', 'Pot', 'Potato', 'PotatoSliced', 'RemoteControl', 'SaltShaker', 'ScrubBrush', 'SoapBar', 'SoapBottle', 'Spatula', 'Spoon', 'SprayBottle', 'Statue', 'TableTopDecor', 'TeddyBear', 'TennisRacket', 'TissueBox', 'ToiletPaper', 'Tomato', 'TomatoSliced', 'Towel', 'Vase', 'Watch', 'WateringCan', 'WineBottle']
BOOLSET_OBJECT_TYPES module-attribute
BOOLSET_OBJECT_TYPES = {'AlarmClock', 'Apple', 'ArmChair', 'BasketBall', 'Bed', 'Book', 'Boots', 'Bottle', 'Bowl', 'Box', 'Bread', 'ButterKnife', 'CD', 'Cabinet', 'Candle', 'Cart', 'CellPhone', 'Chair', 'Cloth', 'ClothesDryer', 'CoffeeMachine', 'CoffeeTable', 'CounterTop', 'CreditCard', 'Cup', 'Desk', 'DeskLamp', 'Desktop', 'DiningTable', 'DishSponge', 'DogBed', 'Drawer', 'Dresser', 'Dumbbell', 'Egg', 'Faucet', 'FloorLamp', 'Fork', 'Fridge', 'GarbageBag', 'GarbageCan', 'HousePlant', 'Kettle', 'KeyChain', 'Knife', 'Ladle', 'Laptop', 'LaundryHamper', 'Lettuce', 'Microwave', 'Mug', 'Newspaper', 'Ottoman', 'Painting', 'Pan', 'PaperTowelRoll', 'Pen', 'Pencil', 'PepperShaker', 'Pillow', 'Plate', 'Plunger', 'Pot', 'Potato', 'RemoteControl', 'Safe', 'SaltShaker', 'Shelf', 'ShelvingUnit', 'SideTable', 'Sink', 'SinkBasin', 'SoapBar', 'SoapBottle', 'Sofa', 'Spatula', 'Spoon', 'SprayBottle', 'Statue', 'Stool', 'TVStand', 'TeddyBear', 'Television', 'TennisRacket', 'TissueBox', 'Toaster', 'Toilet', 'ToiletPaper', 'Tomato', 'VacuumCleaner', 'Vase', 'WashingMachine', 'Watch', 'WineBottle'}
EXTENDED_ARTICULATION_TYPES_THOR module-attribute
EXTENDED_ARTICULATION_TYPES_THOR = ALL_ARTICULATION_TYPES_THOR + ITHOR_ARTICULATED_OBJECTS
ITHOR_ARTICULATED_OBJECTS module-attribute
ITHOR_ARTICULATED_OBJECTS = ['cabinet', 'drawer', 'oven', 'dishwasher', 'showerdoor']
OBJNAV_SYNSETS module-attribute
OBJNAV_SYNSETS = [(AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET[ot]) for ot in OBJNAV_TYPES_THOR]
OBJNAV_TYPES_THOR module-attribute
OBJNAV_TYPES_THOR = ['AlarmClock', 'Apple', 'BasketBall', 'Bed', 'Bowl', 'Chair', 'GarbageCan', 'HousePlant', 'Laptop', 'Mug', 'Sofa', 'SprayBottle', 'Television', 'Toilet', 'Vase']
PICKUP_SYNSETS module-attribute
PICKUP_SYNSETS = [(AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET[ot]) for ot in PICKUP_TYPES_THOR]
PICKUP_TYPES_THOR module-attribute
PICKUP_TYPES_THOR = ['AlarmClock', 'Apple', 'BasketBall', 'Bowl', 'Laptop', 'Mug', 'SprayBottle', 'Vase']
PICK_AND_PLACE_OBJECTS module-attribute
PICK_AND_PLACE_OBJECTS = ['alarm_clock', 'aluminum_foil', 'apple', 'bottle', 'bread', 'butterknife', 'candle', 'cd', 'cellphone', 'cloth', 'creditcard', 'cup', 'dish_sponge', 'egg', 'egg_cracked', 'fork', 'hand_towel', 'keychain', 'knife', 'ladle', 'mug', 'newspaper', 'paper_towel', 'pen', 'pencil', 'pepper_shaker', 'potato', 'remote', 'salt_shaker', 'scrub_brush', 'soap_bar', 'soap_bottle', 'spatula', 'spoon', 'spray_bottle', 'tissue_box', 'toilet_paper', 'toilet_paper_used_up', 'tomato', 'towel_statue', 'watch']
RECEPTACLE_SYNSETS module-attribute
RECEPTACLE_SYNSETS = [(AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET[ot]) for ot in RECEPTACLE_TYPES_THOR]
RECEPTACLE_TYPES_THOR module-attribute
RECEPTACLE_TYPES_THOR = ['ArmChair', 'Bed', 'Chair', 'CoffeeTable', 'CounterTop', 'Desk', 'DiningTable', 'Dresser', 'Shelf', 'SideTable', 'Sofa', 'Stool', 'TVStand']
RELATIVE_ANCHOR_TYPES module-attribute
RELATIVE_ANCHOR_TYPES = {'Bed', 'CounterTop', 'DiningTable', 'Fridge', 'Sink', 'Sofa', 'Television', 'Toilet'}
TARGET_EXCLUDED_SYNSETS module-attribute
TARGET_EXCLUDED_SYNSETS = {'knickknack.n.01', 'countertop.n.01', 'doorway.n.01', 'shelf.n.01', 'decoration.n.01', 'window.n.01', 'doorframe.n.01', 'wall.n.01', 'drawer.n.01', 'floor.n.01', 'arch.n.03', 'needle.n.03', 'tank_car.n.01', 'swatch.n.01', 'visor.n.01', 'arrow.n.01', 'plug.n.01', 'lung.n.01', 'organ.n.01', 'monocle.n.01', 'power_tool.n.01', 'logo.n.01', 'spark_plug.n.01', 'optical_illusion.n.01', 'prize.n.01', 'window.n.04', 'window.n.07', 'projector.n.01', 'pack.n.09'}
THOR_OBJNAV_OBJECTS_LOWERCASE module-attribute
THOR_OBJNAV_OBJECTS_LOWERCASE = [(replace('_', '')) for x in THOR_OBJNAV_OBJECTS_LOWERCASE]
THOR_PICKUP_OBJECTS_LOWERCASE module-attribute
THOR_PICKUP_OBJECTS_LOWERCASE = [(replace('_', '')) for x in THOR_PICKUP_OBJECTS_LOWERCASE]
bad_asset_ids
bad_asset_ids()
Source code in molmo_spaces/utils/constants/object_constants.py
def bad_asset_ids():
    """Return the asset-mismatch reference dataset, loading it at most once.

    On the first call the dataset is fetched with ``prior.load_dataset``
    (dataset ``"vida-additional-references"``, revision ``"asset-mismatches"``)
    and memoized in the module-level ``_cached_bad_asset_ids``; later calls
    return the cached object.
    """
    # Deferred import; presumably so importing this module does not require
    # `prior` to be installed.
    import prior

    global _cached_bad_asset_ids
    if _cached_bad_asset_ids is not None:
        return _cached_bad_asset_ids

    _cached_bad_asset_ids = prior.load_dataset(
        dataset="vida-additional-references",
        revision="asset-mismatches",
    )
    return _cached_bad_asset_ids

simulation_constants

Attributes:

Name Type Description
OBJAVERSE_FREE_JOINT_DEFAULT_DAMPING
OBJAVERSE_FREE_JOINT_DEFAULT_DAMPING module-attribute
OBJAVERSE_FREE_JOINT_DEFAULT_DAMPING = 0.001

controller_utils

Functions:

Name Description
find_nearest_equivalent_angle
optimize_all_steer_and_drive
optimize_steer_and_drive

find_nearest_equivalent_angle

find_nearest_equivalent_angle(curr, target, steer_angle_range)
Source code in molmo_spaces/utils/controller_utils.py
def find_nearest_equivalent_angle(curr, target, steer_angle_range):
    """Return the 2*pi-equivalent of ``target`` inside the steering limits closest to ``curr``.

    Every angle ``target + 2*pi*k`` (integer ``k``) lying within
    ``[steer_angle_range[0], steer_angle_range[1]]`` is considered. The one with
    the smallest absolute difference from ``curr`` is returned; on ties, the
    lowest angle wins. If no equivalent angle fits in the range, ``target``
    itself is clipped to the range instead.

    Args:
        curr: Current steering angle (radians).
        target: Desired steering angle (radians).
        steer_angle_range: ``(low, high)`` steering limits (radians).

    Returns:
        The selected steering angle as a numpy float.
    """
    low, high = steer_angle_range[0], steer_angle_range[1]
    full_turn = 2 * np.pi
    first_k = int(np.floor((low - target) / full_turn))
    last_k = int(np.ceil((high - target) / full_turn))

    # Every whole-turn offset of `target` that could possibly land in [low, high],
    # in ascending order.
    equivalents = np.array([target + k * 2 * np.pi for k in range(first_k, last_k + 1)])
    in_range = equivalents[(equivalents >= low) & (equivalents <= high)]

    if in_range.size == 0:
        return np.clip(target, low, high)
    return in_range[np.argmin(np.abs(in_range - curr))]

optimize_all_steer_and_drive

optimize_all_steer_and_drive(current_angles, target_angles, target_speeds, steer_angle_range, max_wheel_speed)
Source code in molmo_spaces/utils/controller_utils.py
def optimize_all_steer_and_drive(
    current_angles, target_angles, target_speeds, steer_angle_range, max_wheel_speed
):
    """Optimize steering angle and drive speed for every wheel.

    Each wheel is optimized independently with ``optimize_steer_and_drive``,
    which may flip the steering target by pi and negate the speed. If any
    resulting speed magnitude exceeds ``max_wheel_speed``, all speeds are
    scaled down by one common factor. The fastest wheel then runs at exactly
    ``max_wheel_speed`` and the speed ratios between wheels are preserved.

    Args:
        current_angles: Current steering angle per wheel (radians).
        target_angles: Desired steering angle per wheel (radians).
        target_speeds: Desired drive speed per wheel.
        steer_angle_range: ``(low, high)`` steering limits shared by all wheels.
        max_wheel_speed: Maximum allowed drive-speed magnitude.

    Returns:
        Tuple ``(angles, speeds)`` of numpy arrays with one entry per wheel.
        Both are empty when no wheels are given.
    """
    optimized = [
        # optimize_steer_and_drive takes (curr, target, speed, steer_angle_range).
        # The speed limit is enforced below as a global rescale, not per wheel.
        optimize_steer_and_drive(c, t, s, steer_angle_range)
        for c, t, s in zip(current_angles, target_angles, target_speeds)
    ]
    if not optimized:
        # zip(*[]) cannot be unpacked into two sequences.
        return np.array([]), np.array([])

    angles, speeds = zip(*optimized)
    speeds = np.array(speeds)
    if np.any(np.abs(speeds) > max_wheel_speed):
        factor = max_wheel_speed / np.max(np.abs(speeds))
        speeds = speeds * factor
    return np.array(angles), speeds

optimize_steer_and_drive

optimize_steer_and_drive(curr, target, speed, steer_angle_range)
Source code in molmo_spaces/utils/controller_utils.py
def optimize_steer_and_drive(curr, target, speed, steer_angle_range):
    """Choose between steering to ``target`` or to ``target + pi`` with reversed drive.

    Pointing the wheel at ``target + pi`` and driving at ``-speed`` is treated
    as equivalent to pointing at ``target`` and driving at ``speed``. Both
    headings are mapped into ``steer_angle_range`` with
    ``find_nearest_equivalent_angle``. The option needing the smaller steering
    change from ``curr`` is returned, and ties keep the forward option.

    Args:
        curr: Current steering angle (radians).
        target: Desired steering angle (radians).
        speed: Desired drive speed.
        steer_angle_range: ``(low, high)`` steering limits (radians).

    Returns:
        Tuple ``(steer_angle, drive_speed)``.
    """
    forward_angle = find_nearest_equivalent_angle(curr, target, steer_angle_range)
    flipped_angle = find_nearest_equivalent_angle(curr, target + np.pi, steer_angle_range)

    if abs(forward_angle - curr) <= abs(flipped_angle - curr):
        return forward_angle, speed
    return flipped_angle, -speed

depth_utils

Utilities for depth image encoding and decoding.

Optimized for Intel RealSense D405 camera specs: the D405's actual spec is a 7–50 cm range with ±1.4% error at 20 cm (±2.8 mm); the encoding range is 5–55 cm (extended for margin); resolution 1280x720; 18 mm baseline; global shutter.

Depth images are encoded as 16-bit values across RG channels: 1. High precision: 7.6 microns over 50cm range (65,535 codes for valid data, i.e. 65,534 quantization steps) 2. Video compatibility: Standard RGB video codecs (H.264 RGB) 3. Efficient lossy compression: Unused B channel reduces artifacts 4. Smaller file sizes vs 24-bit encoding 5. Invalid data handling: 0 reserved for missing/out-of-range pixels

The encoding range (5-55cm) extends slightly beyond D405's spec (7-50cm) to: - Provide margin for edge cases and measurement noise - Still maintain excellent precision (7.6μm vs 15μm with wider ranges) - Keep compression efficient (tight dynamic range = better lossy codec performance)

Invalid/missing data convention: - Pixels outside [DEPTH_MIN, DEPTH_MAX] are encoded as 0 (not clipped) - This allows easy masking: valid_mask = depth > 0 - Common for far-away regions or sensor failures in real-world depth cameras

Functions:

Name Description
compute_depth_encoding_stats

Compute statistics about depth encoding precision.

decode_depth_from_rgb

Decode RG-encoded depth back to metric depth in meters.

detect_depth_edges

Detect depth discontinuities (edges) where compression artifacts are expected.

encode_depth_to_rgb

Encode metric depth values as 16-bit RG channels for video storage.

load_depth_video

Load depth video and decode frames back to metric depth.

print_depth_stats

Print detailed depth statistics to console.

save_depth_video

Save depth frames as compressed video.

validate_roundtrip_accuracy

Validate that depth encoding/decoding roundtrip is accurate.

visualize_depth_error

Visualize the compression error between original and decoded depth.

visualize_depth_image

Visualize depth image with statistics and save to debug file.

Attributes:

Name Type Description
DEPTH_MAX
DEPTH_MIN
DEPTH_VIDEO_CODEC
DEPTH_VIDEO_CRF
DEPTH_VIDEO_PIXELFORMAT
log

DEPTH_MAX module-attribute

DEPTH_MAX = 0.55

DEPTH_MIN module-attribute

DEPTH_MIN = 0.05

DEPTH_VIDEO_CODEC module-attribute

DEPTH_VIDEO_CODEC = 'libx264rgb'

DEPTH_VIDEO_CRF module-attribute

DEPTH_VIDEO_CRF = '23'

DEPTH_VIDEO_PIXELFORMAT module-attribute

DEPTH_VIDEO_PIXELFORMAT = 'rgb24'

log module-attribute

log = getLogger(__name__)

compute_depth_encoding_stats

compute_depth_encoding_stats(depth_meters: ndarray) -> dict

Compute statistics about depth encoding precision.

Useful for validating that the depth range and encoding are appropriate for your specific use case.

Parameters:

Name Type Description Default
depth_meters ndarray

(H, W) float32 array of depth values in meters

required

Returns:

Type Description
dict

Dictionary with statistics:

dict
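  • min_depth: Minimum depth among valid pixels (within [DEPTH_MIN, DEPTH_MAX]); 0.0 if there are none
dict
  • max_depth: Maximum depth among valid pixels; 0.0 if there are none
dict
  • mean_depth: Mean depth among valid pixels; 0.0 if there are none
dict
  • invalid_pixels: Number of pixels outside [DEPTH_MIN, DEPTH_MAX] (will be encoded as 0)
dict
  • invalid_fraction: invalid_pixels divided by the total number of pixels
dict
  • precision_meters: Theoretical precision in meters (at current encoding)
dict
  • precision_mm: Theoretical precision in millimeters
dict
  • depth_range: DEPTH_MAX - DEPTH_MIN in meters
dict
  • encoding_bits: Number of bits used by the encoding (16)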
Source code in molmo_spaces/utils/depth_utils.py
def compute_depth_encoding_stats(depth_meters: np.ndarray) -> dict:
    """Compute statistics about depth encoding precision.

    Useful for validating that the depth range and encoding are appropriate
    for your specific use case.

    Args:
        depth_meters: (H, W) float32 array of depth values in meters

    Returns:
        Dictionary with statistics:
        - min_depth: Minimum depth among valid pixels (0.0 if there are none)
        - max_depth: Maximum depth among valid pixels (0.0 if there are none)
        - mean_depth: Mean depth among valid pixels (0.0 if there are none)
        - invalid_pixels: Number of pixels that fail the [DEPTH_MIN, DEPTH_MAX]
          test, including NaN (these are encoded as 0)
        - invalid_fraction: invalid_pixels / total pixels (0.0 for an empty array)
        - precision_meters: Theoretical quantization step in meters
        - precision_mm: Theoretical quantization step in millimeters
        - depth_range: DEPTH_MAX - DEPTH_MIN in meters
        - encoding_bits: Number of bits used by the encoding (16)
    """
    depth_range = DEPTH_MAX - DEPTH_MIN
    # Valid codes are [1, 65535] (0 is the invalid marker): 65534 steps.
    precision = depth_range / 65534.0

    # Same validity test as encode_depth_to_rgb, so the invalid count matches
    # what actually gets encoded as 0. NaN fails both comparisons and is
    # therefore counted as invalid too.
    valid_mask = (depth_meters >= DEPTH_MIN) & (depth_meters <= DEPTH_MAX)
    valid_depths = depth_meters[valid_mask]
    has_valid = valid_depths.size > 0

    total = depth_meters.size
    invalid = total - valid_depths.size

    return {
        "min_depth": float(np.min(valid_depths)) if has_valid else 0.0,
        "max_depth": float(np.max(valid_depths)) if has_valid else 0.0,
        "mean_depth": float(np.mean(valid_depths)) if has_valid else 0.0,
        "invalid_pixels": int(invalid),
        # Avoid ZeroDivisionError for an empty input array.
        "invalid_fraction": float(invalid) / total if total else 0.0,
        "precision_meters": precision,
        "precision_mm": precision * 1000.0,
        "depth_range": depth_range,
        "encoding_bits": 16,
    }

decode_depth_from_rgb

decode_depth_from_rgb(rgb_frame: ndarray, validate: bool = True) -> ndarray

Decode RG-encoded depth back to metric depth in meters.

Reverses the encoding from encode_depth_to_rgb() to recover floating-point depth values from uint8 RG channels.

Encoded value of 0 (RGB(0,0,0)) represents invalid/missing data and is decoded to 0.0 meters.

Parameters:

Name Type Description Default
rgb_frame ndarray

(H, W, 3) uint8 array with depth encoded in RG channels

required
validate bool

If True, warns if B channel is non-zero (indicates wrong pixel format)

True

Returns:

Name Type Description
depth_meters ndarray

(H, W) float32 array of depth values in meters. Valid pixels in range [DEPTH_MIN, DEPTH_MAX]. Invalid pixels are 0.0 (use depth > 0 to mask valid data).

Example

depth_original = np.array([[0.1, 0.25], [0.4, 0.5]], dtype=np.float32) rgb = encode_depth_to_rgb(depth_original) depth_decoded = decode_depth_from_rgb(rgb) np.allclose(depth_original, depth_decoded, atol=0.001) True (depths outside [DEPTH_MIN, DEPTH_MAX], e.g. 1.0 m, are stored as invalid and decode to 0.0)

Source code in molmo_spaces/utils/depth_utils.py
def decode_depth_from_rgb(rgb_frame: np.ndarray, validate: bool = True) -> np.ndarray:
    """Decode RG-encoded depth back to metric depth in meters.

    Reverses the encoding from encode_depth_to_rgb() to recover
    floating-point depth values from uint8 RG channels.

    Encoded value of 0 (RGB(0,0,0)) represents invalid/missing data and
    is decoded to 0.0 meters.

    Args:
        rgb_frame: (H, W, 3) uint8 array with depth encoded in RG channels
        validate: If True, warns if B channel is non-zero (indicates wrong pixel format)

    Returns:
        depth_meters: (H, W) float32 array of depth values in meters.
                     Valid pixels in range [DEPTH_MIN, DEPTH_MAX].
                     Invalid pixels are 0.0 (use depth > 0 to mask valid data).

    Example:
        >>> depth_original = np.array([[0.1, 0.25], [0.4, 0.5]], dtype=np.float32)
        >>> rgb = encode_depth_to_rgb(depth_original)
        >>> depth_decoded = decode_depth_from_rgb(rgb)
        >>> np.allclose(depth_original, depth_decoded, atol=0.001)
        True
        >>> # Depths outside [DEPTH_MIN, DEPTH_MAX] are stored as invalid (0.0)
        >>> decode_depth_from_rgb(encode_depth_to_rgb(np.array([[1.0]], dtype=np.float32)))
        array([[0.]], dtype=float32)
    """
    # Validate B channel (should be ~0 for properly encoded depth)
    if validate:
        b_mean = float(np.mean(rgb_frame[:, :, 2]))
        if b_mean > 5.0:
            log.warning(
                f"B channel has non-zero values (mean={b_mean:.1f}). "
                "This may indicate YUV pixel format was used instead of RGB, "
                "which causes chroma subsampling artifacts. "
                "Use load_depth_video() or imageio with pixelformat='rgb24' to avoid this."
            )

    # Reconstruct 16-bit integer from RG channels (ignore B channel).
    # uint32 avoids overflow of R * 256 in uint8 arithmetic.
    depth_16bit = (
        rgb_frame[:, :, 0].astype(np.uint32) * np.uint32(256)  # R * 2^8
        + rgb_frame[:, :, 1].astype(np.uint32)  # G * 2^0
    )

    # Identify valid pixels (encoded as non-zero)
    valid_mask = depth_16bit > 0

    # Decode valid pixels: [1, 65535] -> [DEPTH_MIN, DEPTH_MAX] (float32 throughout)
    depth_normalized = (depth_16bit.astype(np.float32) - 1.0) / 65534.0
    metric_depth = depth_normalized * np.float32(DEPTH_MAX - DEPTH_MIN) + np.float32(DEPTH_MIN)

    # Invalid pixels (code 0) -> 0.0. metric_depth is a fresh array, so the
    # in-place write avoids an extra full-frame allocation.
    metric_depth[~valid_mask] = 0.0

    return metric_depth.astype(np.float32, copy=False)

detect_depth_edges

detect_depth_edges(depth: ndarray, gradient_threshold_mm: float = 50.0) -> ndarray

Detect depth discontinuities (edges) where compression artifacts are expected.

Used for analysis/visualization only - not part of encoding/decoding pipeline.

Parameters:

Name Type Description Default
depth ndarray

(H, W) depth array in meters

required
gradient_threshold_mm float

Depth gradient threshold in mm to classify as edge

50.0

Returns:

Name Type Description
edge_mask ndarray

(H, W) boolean array, True at edge pixels

Source code in molmo_spaces/utils/depth_utils.py
def detect_depth_edges(depth: np.ndarray, gradient_threshold_mm: float = 50.0) -> np.ndarray:
    """Flag pixels near depth discontinuities, where compression artifacts are expected.

    Intended for analysis/visualization only; the encoding/decoding pipeline
    does not use it.

    Args:
        depth: (H, W) depth array in meters
        gradient_threshold_mm: Threshold on the Sobel gradient magnitude of the
            depth expressed in millimeters; larger values count as edges

    Returns:
        edge_mask: (H, W) boolean array, True at edge pixels. The thresholded
        mask is grown by 2 iterations of binary dilation so that pixels next to
        a discontinuity are included.

    Raises:
        ImportError: If scipy is not installed.
    """
    try:
        from scipy import ndimage
    except ImportError as err:
        raise ImportError("Edge detection requires scipy. Install with: pip install scipy") from err

    depth_mm = depth * 1000.0
    grad_x, grad_y = (ndimage.sobel(depth_mm, axis=ax) for ax in (1, 0))
    magnitude = np.sqrt(grad_x**2 + grad_y**2)

    # Widen the raw edge set by ~2 pixels to cover nearby codec artifacts.
    return ndimage.binary_dilation(magnitude > gradient_threshold_mm, iterations=2)

encode_depth_to_rgb

encode_depth_to_rgb(depth_meters: ndarray) -> ndarray

Encode metric depth values as 16-bit RG channels for video storage.

Converts floating-point depth values (in meters) to uint8 RG encoding. Provides ~7.6 micron precision over the 50cm range using 16-bit encoding. The B channel is set to 0, which helps with lossy video compression.

Invalid pixels (outside [DEPTH_MIN, DEPTH_MAX]) are encoded as 0, allowing downstream processing to use depth_mask = depth > 0 to identify valid data.

Parameters:

Name Type Description Default
depth_meters ndarray

(H, W) float32 array of depth values in meters. Values outside [DEPTH_MIN, DEPTH_MAX] are set to 0 (invalid).

required

Returns:

Name Type Description
rgb_frame ndarray

(H, W, 3) uint8 array with depth encoded as: - R channel: bits 8-15 (high byte) - G channel: bits 0-7 (low byte) - B channel: 0 (unused, helps compression) - RGB(0,0,0): invalid/missing data

Example

depth = np.array([[0.5, 1.0], [0.1, 1.0]], dtype=np.float32) rgb = encode_depth_to_rgb(depth) rgb.shape (2, 2, 3) rgb.dtype dtype('uint8')

Source code in molmo_spaces/utils/depth_utils.py
def encode_depth_to_rgb(depth_meters: np.ndarray) -> np.ndarray:
    """Encode metric depth values as 16-bit RG channels for video storage.

    Converts floating-point depth values (in meters) to uint8 RG encoding.
    Provides ~7.6 micron quantization steps over the 50cm range using 16-bit
    encoding. Values are rounded to the nearest code, so the roundtrip error
    through decode_depth_from_rgb() is at most about half a step (plus float32
    rounding). The B channel is set to 0, which helps with lossy video
    compression.

    Invalid pixels (outside [DEPTH_MIN, DEPTH_MAX], including NaN) are encoded
    as 0, allowing downstream processing to use `depth_mask = depth > 0` to
    identify valid data.

    Args:
        depth_meters: (H, W) float32 array of depth values in meters.
                     Values outside [DEPTH_MIN, DEPTH_MAX] are set to 0 (invalid).

    Returns:
        rgb_frame: (H, W, 3) uint8 array with depth encoded as:
                  - R channel: bits 8-15 (high byte)
                  - G channel: bits 0-7 (low byte)
                  - B channel: 0 (unused, helps compression)
                  - RGB(0,0,0): invalid/missing data

    Example:
        >>> depth = np.array([[0.5, 1.0], [0.1, 1.0]], dtype=np.float32)
        >>> rgb = encode_depth_to_rgb(depth)
        >>> rgb.shape
        (2, 2, 3)
        >>> rgb.dtype
        dtype('uint8')
    """
    # Identify valid pixels (within depth range). NaN fails both comparisons
    # and is therefore treated as invalid.
    valid_mask = (depth_meters >= DEPTH_MIN) & (depth_meters <= DEPTH_MAX)

    # Normalize valid depths to [0, 1] and pin invalid pixels to 0.0 *before*
    # any integer conversion. np.where evaluates both branches, so masking only
    # after the uint16 cast would still push negative/NaN/inf values through
    # an out-of-range float -> uint16 cast (platform-dependent, may warn).
    depth_normalized = np.where(
        valid_mask, (depth_meters - DEPTH_MIN) / (DEPTH_MAX - DEPTH_MIN), 0.0
    )

    # Map valid pixels to codes [1, 65535], reserving 0 for invalid data.
    # Clipping guards against float rounding pushing DEPTH_MAX slightly above
    # 1.0. Rounding to nearest instead of truncating removes the downward bias
    # and halves the worst-case quantization error.
    codes = np.rint(np.clip(depth_normalized, 0.0, 1.0) * 65534.0) + 1.0
    depth_16bit = np.where(valid_mask, codes, 0.0).astype(np.uint16)

    # Split into RG channels, leave B as 0
    h, w = depth_meters.shape
    rgb_frame = np.zeros((h, w, 3), dtype=np.uint8)
    rgb_frame[:, :, 0] = (depth_16bit >> 8) & 0xFF  # R: bits 8-15
    rgb_frame[:, :, 1] = depth_16bit & 0xFF  # G: bits 0-7
    # rgb_frame[:, :, 2] remains 0                   # B: unused

    return rgb_frame

load_depth_video

load_depth_video(video_path: str | Path, logger: Logger | None = None) -> ndarray

Load depth video and decode frames back to metric depth.

Companion function to save_depth_video(). Ensures proper codec settings for reading depth videos (RGB pixel format, no YUV conversion).

Parameters:

Name Type Description Default
video_path str | Path

Path to the depth video file (.mp4)

required
logger Logger | None

Optional logger for debugging

None

Returns:

Name Type Description
depth_frames ndarray

(T, H, W) float32 array of depth values in meters

Example
Save and load round-trip

depth_original = np.random.rand(10, 480, 640).astype(np.float32) * 0.4 + 0.1 save_depth_video(depth_original, "test_depth.mp4") depth_loaded = load_depth_video("test_depth.mp4") depth_loaded.shape (10, 480, 640)

Source code in molmo_spaces/utils/depth_utils.py
def load_depth_video(
    video_path: str | Path,
    logger: logging.Logger | None = None,
) -> np.ndarray:
    """Load depth video and decode frames back to metric depth.

    Companion function to save_depth_video(). Requests ``rgb24`` output from the
    ffmpeg reader so frames are not routed through a YUV conversion, which would
    corrupt the 16-bit RG depth encoding. If opening with that pixel format fails,
    it falls back to the reader's default settings and logs a warning.

    Args:
        video_path: Path to the depth video file (.mp4)
        logger: Optional logger for debugging (defaults to the module logger)

    Returns:
        depth_frames: (T, H, W) float32 array of depth values in meters. If the
        video yields no frames, an empty float32 array is returned instead.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If a decoded frame does not have exactly 3 channels.

    Example:
        >>> # Save and load round-trip
        >>> depth_original = np.random.rand(10, 480, 640).astype(np.float32) * 0.4 + 0.1
        >>> save_depth_video(depth_original, "test_depth.mp4")
        >>> depth_loaded = load_depth_video("test_depth.mp4")
        >>> depth_loaded.shape
        (10, 480, 640)
    """
    import imageio

    logger = logger or log

    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    logger.debug(f"Loading depth video from {video_path}")

    # Use imageio with RGB pixel format (critical - avoids YUV conversion artifacts)
    try:
        reader = imageio.get_reader(
            str(video_path),
            format="ffmpeg",
            pixelformat="rgb24",  # Force RGB to avoid chroma subsampling
        )
    except (ImportError, OSError, ValueError) as e:
        logger.warning(f"Failed to open with pixelformat='rgb24': {e}. Trying default...")
        reader = imageio.get_reader(str(video_path), format="ffmpeg")

    decoded_frames = []
    try:
        for frame_rgb in reader:
            # Validate frame format
            if frame_rgb.shape[-1] != 3:
                raise ValueError(f"Expected RGB frame with 3 channels, got shape {frame_rgb.shape}")

            decoded_frames.append(decode_depth_from_rgb(frame_rgb, validate=True))
    finally:
        # Release the reader even when a frame fails validation or decoding.
        reader.close()

    depth_frames = np.array(decoded_frames, dtype=np.float32)

    # min()/max() below would raise on a zero-size array, so bail out early.
    if depth_frames.size == 0:
        logger.warning(f"No depth frames decoded from {video_path}")
        return depth_frames

    logger.debug(
        f"Loaded {len(depth_frames)} depth frames: {depth_frames.shape}, "
        f"range [{depth_frames.min():.3f}m, {depth_frames.max():.3f}m]"
    )

    return depth_frames

print_depth_stats

print_depth_stats(depth_meters: ndarray, name: str = 'Depth')

Print detailed depth statistics to console.

Parameters:

Name Type Description Default
depth_meters ndarray

(H, W) float32 array of depth values in meters

required
name str

Name to display in the output (e.g., "Wrist Camera Depth")

'Depth'
Source code in molmo_spaces/utils/depth_utils.py
def print_depth_stats(depth_meters: np.ndarray, name: str = "Depth"):
    """Print a human-readable summary of depth encoding statistics.

    The summary comes from ``compute_depth_encoding_stats``. When out-of-range
    pixels exist, they are broken down into those below ``DEPTH_MIN`` and
    those above ``DEPTH_MAX``.

    Args:
        depth_meters: (H, W) float32 array of depth values in meters
        name: Label shown in the header line (e.g., "Wrist Camera Depth")
    """
    summary = compute_depth_encoding_stats(depth_meters)
    n_invalid = summary["invalid_pixels"]

    print(f"\n{name} Statistics:")
    print(f"  Range: [{summary['min_depth']:.3f}m, {summary['max_depth']:.3f}m]")
    print(f"  Mean: {summary['mean_depth']:.3f}m")
    print(f"  Encoding precision: {summary['precision_mm']:.4f}mm ({summary['precision_meters']:.6f}m)")
    print(f"  Valid range: {DEPTH_MIN}m - {DEPTH_MAX}m")

    if not n_invalid > 0:
        print("  All pixels within valid range")
        return

    too_close = np.sum(depth_meters < DEPTH_MIN)
    too_far = np.sum(depth_meters > DEPTH_MAX)
    print(
        f"  Invalid pixels (will be set to 0): {n_invalid:,} ({summary['invalid_fraction'] * 100:.2f}%)"
    )
    print(f"    Below {DEPTH_MIN}m: {too_close:,}")
    print(f"    Above {DEPTH_MAX}m: {too_far:,}")

save_depth_video

save_depth_video(depth_frames: ndarray, video_path: str | Path, fps: float = 10, logger: Logger | None = None) -> None

Save depth frames as compressed video.

This is the single source of truth for depth video compression settings. Encodes depth frames using 16-bit RG encoding and saves with configured codec.

Parameters:

Name Type Description Default
depth_frames ndarray

(T, H, W) float32 array of depth values in meters

required
video_path str | Path

Path to save the video file

required
fps float

Frames per second for the video

10
logger Logger | None

Optional logger for debugging

None
Example

>>> depth_frames = np.random.rand(100, 480, 640).astype(np.float32) * 0.5 + 0.3
>>> save_depth_video(depth_frames, "depth.mp4")

Source code in molmo_spaces/utils/depth_utils.py
def save_depth_video(
    depth_frames: np.ndarray,
    video_path: str | Path,
    fps: float = 10,
    logger: logging.Logger | None = None,
) -> None:
    """Save depth frames as compressed video.

    This is the single source of truth for depth video compression settings.
    Encodes depth frames using 16-bit RG encoding and saves with configured codec.
    A non-``.mp4`` suffix on ``video_path`` is replaced with ``.mp4``, and missing
    parent directories are created.

    Args:
        depth_frames: (T, H, W) float32 array of depth values in meters
        video_path: Path to save the video file
        fps: Frames per second for the video
        logger: Optional logger for debugging (defaults to the module logger)

    Raises:
        ValueError: If ``depth_frames`` is not 3-D or contains no pixels.

    Example:
        >>> depth_frames = np.random.rand(100, 480, 640).astype(np.float32) * 0.5 + 0.3
        >>> save_depth_video(depth_frames, "depth.mp4")
    """
    import imageio

    logger = logger or log

    # Validate input
    if depth_frames.ndim != 3:
        raise ValueError(f"Expected 3D array (T, H, W), got shape {depth_frames.shape}")
    if depth_frames.size == 0:
        # Reject empty input explicitly instead of failing inside min()/max() below.
        raise ValueError(f"Cannot save an empty depth video, got shape {depth_frames.shape}")

    if depth_frames.dtype != np.float32:
        logger.warning(f"Depth frames are {depth_frames.dtype}, expected float32. Converting...")
        depth_frames = depth_frames.astype(np.float32)

    # Check depth range and warn about potential issues
    depth_min = float(depth_frames.min())
    depth_max = float(depth_frames.max())

    # Warn if depth values seem wrong (likely wrong units)
    if depth_min > 1.0 or depth_max > 10.0:
        logger.warning(
            f"Depth range [{depth_min:.2f}m, {depth_max:.2f}m] seems large. "
            "Are you sure values are in meters (not mm or cm)? "
            f"Encoding range is {DEPTH_MIN}m-{DEPTH_MAX}m for D405 camera."
        )

    # Warn if excessive invalid pixels will occur (they are encoded as 0)
    invalid_pixels = np.sum((depth_frames < DEPTH_MIN) | (depth_frames > DEPTH_MAX))
    invalid_fraction = invalid_pixels / depth_frames.size

    if invalid_fraction > 0.3:
        logger.warning(
            f"Warning: {invalid_fraction * 100:.1f}% of pixels will be set to 0 (invalid) "
            f"(outside {DEPTH_MIN}m-{DEPTH_MAX}m range). "
            f"Actual range: [{depth_min:.3f}m, {depth_max:.3f}m]. "
            "Consider adjusting DEPTH_MIN/DEPTH_MAX or checking your depth data."
        )

    # Log depth statistics for debugging
    logger.debug(
        f"Saving depth video: {depth_frames.shape} frames, "
        f"range [{depth_min:.3f}m, {depth_max:.3f}m]"
    )

    # Encode each depth frame to RGB
    encoded_frames = np.stack([encode_depth_to_rgb(frame) for frame in depth_frames])

    logger.debug(f"Encoded {len(encoded_frames)} depth frames to RGB: {encoded_frames.shape}")

    # Prepare video path
    video_path = Path(video_path)
    os.makedirs(video_path.parent, exist_ok=True)
    if video_path.suffix != ".mp4":
        video_path = video_path.with_suffix(".mp4")

    # Configure codec - ALL depth compression settings in one place
    codec_kwargs = {
        "codec": DEPTH_VIDEO_CODEC,
        "pixelformat": DEPTH_VIDEO_PIXELFORMAT,
        # ffmpeg command-line arguments must be strings; str() is a no-op when
        # the constant is already a string.
        "output_params": ["-crf", str(DEPTH_VIDEO_CRF)],
    }

    # Save video
    writer = None
    try:
        writer = imageio.get_writer(str(video_path), format="ffmpeg", fps=fps, **codec_kwargs)
        for frame in encoded_frames:
            writer.append_data(frame)
        writer.close()
        logger.debug(f"Saved depth video to {video_path}")
    except (ImportError, OSError, ValueError, RuntimeError) as e:
        logger.warning(f"FFmpeg writer failed ({type(e).__name__}: {e}), falling back to mimwrite")
        if writer is not None:
            # Release the half-written writer before the fallback rewrites the same path.
            try:
                writer.close()
            except (OSError, ValueError, RuntimeError):
                pass  # best effort: the fallback below overwrites the file anyway
        imageio.mimwrite(str(video_path), encoded_frames, format="mp4", fps=fps, **codec_kwargs)

validate_roundtrip_accuracy

validate_roundtrip_accuracy(depth_meters: ndarray, tolerance_mm: float = 0.1) -> dict

Validate that depth encoding/decoding roundtrip is accurate.

Parameters:

Name Type Description Default
depth_meters ndarray

(H, W) float32 array of depth values in meters

required
tolerance_mm float

Maximum acceptable error in millimeters

0.1

Returns:

Type Description
dict

Dictionary with validation results:

dict
  • max_error_mm: Maximum error in millimeters (valid pixels only)
dict
  • mean_error_mm: Mean error in millimeters (valid pixels only)
dict
  • passed: Whether the roundtrip is within tolerance
dict
  • invalid_preserved: Whether invalid pixels are correctly encoded as 0
dict
  • errors: (H, W) array of absolute errors in meters
Source code in molmo_spaces/utils/depth_utils.py
def validate_roundtrip_accuracy(depth_meters: np.ndarray, tolerance_mm: float = 0.1) -> dict:
    """Validate that depth encoding/decoding roundtrip is accurate.

    Runs ``depth_meters`` through ``encode_depth_to_rgb`` and
    ``decode_depth_from_rgb`` and compares the result with the input.

    Args:
        depth_meters: (H, W) float32 array of depth values in meters
        tolerance_mm: Maximum acceptable error in millimeters

    Returns:
        Dictionary with validation results:
        - max_error_mm: Maximum error in millimeters (valid pixels only; 0.0 if none)
        - mean_error_mm: Mean error in millimeters (valid pixels only; 0.0 if none)
        - tolerance_mm: The tolerance that was applied
        - passed: Whether the roundtrip is within tolerance (Python bool)
        - invalid_preserved: Whether invalid pixels decode to 0 (Python bool)
        - errors: (H, W) array of absolute errors in meters, for ALL pixels
          (including out-of-range ones)
    """
    # Encode and decode
    encoded = encode_depth_to_rgb(depth_meters)
    decoded = decode_depth_from_rgb(encoded)

    # Valid pixels (within the encodable range) should be accurate within tolerance.
    valid_mask = (depth_meters >= DEPTH_MIN) & (depth_meters <= DEPTH_MAX)
    errors = np.abs(decoded - depth_meters)

    if np.any(valid_mask):
        valid_errors = errors[valid_mask]
        max_error_mm = float(np.max(valid_errors)) * 1000.0
        mean_error_mm = float(np.mean(valid_errors)) * 1000.0
    else:
        max_error_mm = mean_error_mm = 0.0

    # Invalid pixels should decode to exactly 0.0. np.all() of an empty
    # selection is True. bool() converts numpy's np.bool_ to a Python bool so
    # the result stays JSON-serializable.
    invalid_preserved = bool(np.all(decoded[~valid_mask] == 0.0))

    return {
        "max_error_mm": max_error_mm,
        "mean_error_mm": mean_error_mm,
        "tolerance_mm": tolerance_mm,
        "passed": max_error_mm <= tolerance_mm,
        "invalid_preserved": invalid_preserved,
        "errors": errors,
    }

visualize_depth_error

visualize_depth_error(original_depth: ndarray, decoded_depth: ndarray, error: ndarray, title: str, save_path: Path | None = None)

Visualize the compression error between original and decoded depth.

Shows where errors occur (smooth regions vs. edges) to help understand compression behavior. Creates a 4-panel visualization:

1. Original depth
2. Decoded depth (after compression)
3. Edge detection (shows discontinuities)
4. Error heatmap

Parameters:

Name Type Description Default
original_depth ndarray

(H, W) float32 array of original depth in meters

required
decoded_depth ndarray

(H, W) float32 array of decoded depth in meters

required
error ndarray

(H, W) float32 array of absolute errors in meters

required
title str

Title for the visualization

required
save_path Path | None

Optional path to save the visualization (PNG)

None
Source code in molmo_spaces/utils/depth_utils.py
def visualize_depth_error(
    original_depth: np.ndarray,
    decoded_depth: np.ndarray,
    error: np.ndarray,
    title: str,
    save_path: Path | None = None,
):
    """Visualize the compression error between original and decoded depth.

    Shows where errors occur (smooth regions vs edges) to understand compression behavior.
    Creates a 4-panel visualization:
    1. Original depth
    2. Decoded depth (after compression)
    3. Edge detection (shows discontinuities)
    4. Error heatmap

    Pixels whose ORIGINAL depth lies outside [DEPTH_MIN, DEPTH_MAX] count as
    invalid. They are excluded from the error statistics and left blank
    (NaN, drawn on a light-gray background) in the error heatmap.

    Args:
        original_depth: (H, W) float32 array of original depth in meters
        decoded_depth: (H, W) float32 array of decoded depth in meters
        error: (H, W) float32 array of absolute errors in meters
        title: Title for the visualization
        save_path: Optional path to save the visualization (PNG). Without it,
            the figure is built and then discarded.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "Visualization requires matplotlib. Install with: pip install matplotlib"
        ) from e

    # Create mask for valid pixels (within valid range in original)
    valid_mask = (original_depth >= DEPTH_MIN) & (original_depth <= DEPTH_MAX)
    num_valid = np.sum(valid_mask)
    num_invalid = np.sum(~valid_mask)

    # Detect edges (for analysis - shows where high errors are expected)
    edge_mask = detect_depth_edges(original_depth, gradient_threshold_mm=50.0)
    smooth_mask = valid_mask & ~edge_mask
    edge_valid_mask = valid_mask & edge_mask

    num_smooth = np.sum(smooth_mask)
    num_edges = np.sum(edge_valid_mask)

    # Convert error to mm
    error_mm = error * 1000.0

    # Create masked error array (NaN for invalid regions)
    error_mm_masked = error_mm.copy()
    error_mm_masked[~valid_mask] = np.nan

    # Colour scale upper bound for the heatmap. A perfect (all-zero) or
    # all-NaN error map would otherwise give a degenerate vmin == vmax range.
    max_error_display = np.nanmax(error_mm_masked) if num_valid > 0 else 5.0
    if not np.isfinite(max_error_display) or max_error_display <= 0:
        max_error_display = 5.0

    # Calculate statistics (separate smooth vs edges for analysis)
    if num_smooth > 0:
        smooth_errors = error_mm[smooth_mask]
        smooth_mean = np.mean(smooth_errors)
        smooth_p95 = np.percentile(smooth_errors, 95)
    else:
        smooth_mean = smooth_p95 = 0.0

    if num_edges > 0:
        edge_errors = error_mm[edge_valid_mask]
        edge_mean = np.mean(edge_errors)
        edge_p95 = np.percentile(edge_errors, 95)
    else:
        edge_mean = edge_p95 = 0.0

    if num_valid > 0:
        valid_errors = error_mm[valid_mask]
        mean_error = np.mean(valid_errors)
        p95_error = np.percentile(valid_errors, 95)
        max_error = np.max(valid_errors)

        stats_text = (
            "Error Statistics:\n"
            f"Valid: {num_valid:,} ({num_valid / error.size * 100:.1f}%) = {num_smooth:,} smooth + {num_edges:,} edges\n"
            f"Invalid: {num_invalid:,} ({num_invalid / error.size * 100:.1f}%)\n"
            "\n"
            f"Smooth regions: mean={smooth_mean:.2f}mm, 95th={smooth_p95:.2f}mm\n"
            f"Edge regions: mean={edge_mean:.2f}mm, 95th={edge_p95:.2f}mm\n"
            f"Overall: mean={mean_error:.2f}mm, 95th={p95_error:.2f}mm, max={max_error:.2f}mm\n"
        )
    else:
        stats_text = "No valid pixels found (all invalid)"

    # Create 4-panel visualization
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    try:
        axes = axes.flatten()

        # 1. Original depth
        ax = axes[0]
        im = ax.imshow(original_depth, cmap="turbo", vmin=DEPTH_MIN, vmax=DEPTH_MAX)
        ax.set_title("Original Depth")
        fig.colorbar(im, ax=ax, label="Depth (m)")

        # 2. Decoded depth
        ax = axes[1]
        im = ax.imshow(decoded_depth, cmap="turbo", vmin=DEPTH_MIN, vmax=DEPTH_MAX)
        ax.set_title("Decoded Depth (After MP4)")
        fig.colorbar(im, ax=ax, label="Depth (m)")

        # 3. Edge detection (shows where we expect higher errors)
        ax = axes[2]
        edge_vis = np.zeros_like(original_depth)
        edge_vis[edge_valid_mask] = 2  # Edges (red)
        edge_vis[smooth_mask] = 1  # Smooth regions (green)
        edge_vis[~valid_mask] = 0  # Invalid (blue)
        im = ax.imshow(edge_vis, cmap="RdYlGn_r", vmin=0, vmax=2)
        ax.set_title("Depth Discontinuities")
        cbar = fig.colorbar(im, ax=ax, ticks=[0, 1, 2])
        cbar.ax.set_yticklabels(["Invalid", "Smooth", "Edge"])

        # 4. Error heatmap
        ax = axes[3]
        im = ax.imshow(error_mm_masked, cmap="hot", vmin=0, vmax=max_error_display)
        ax.set_title("Absolute Error (mm)")
        fig.colorbar(im, ax=ax, label="Error (mm)")
        ax.set_facecolor("lightgray")  # NaN (invalid) pixels show this background

        fig.text(
            0.5,
            0.02,
            stats_text,
            ha="center",
            fontsize=10,
            family="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        fig.suptitle(title, fontsize=14, fontweight="bold")
        fig.tight_layout(rect=[0, 0.15, 1, 0.96])

        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            log.debug(f"Saved error visualization to {save_path}")
    finally:
        # Close THIS figure (not pyplot's "current" one) so repeated calls do
        # not accumulate open figures, even if drawing or saving fails.
        plt.close(fig)

visualize_depth_image

visualize_depth_image(depth_meters: ndarray, title: str, save_path: Path | None = None)

Visualize depth image with statistics and save to debug file.

Creates a 4-panel visualization showing:

1. Raw depth with full range (0-2m)
2. Raw depth with encoding range
3. Valid/invalid pixel visualization (too close/valid/too far)
4. Encoded RGB representation

Parameters:

Name Type Description Default
depth_meters ndarray

(H, W) float32 array of depth values in meters

required
title str

Title for the visualization

required
save_path Path | None

Optional path to save the visualization (PNG)

None

Returns:

Type Description
dict

Dictionary of depth statistics from compute_depth_encoding_stats()

Source code in molmo_spaces/utils/depth_utils.py
def visualize_depth_image(depth_meters: np.ndarray, title: str, save_path: Path | None = None):
    """Visualize depth image with statistics and save to debug file.

    Creates a 4-panel visualization showing:
    1. Raw depth with full range (0-2m)
    2. Raw depth with encoding range
    3. Valid/invalid pixel visualization (too close/valid/too far)
    4. Encoded RGB representation

    Args:
        depth_meters: (H, W) float32 array of depth values in meters
        title: Title for the visualization
        save_path: Optional path to save the visualization (PNG). Without it,
            the figure is built and then discarded.

    Returns:
        Dictionary of depth statistics from compute_depth_encoding_stats()

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "Visualization requires matplotlib. Install with: pip install matplotlib"
        ) from e

    # Compute stats and the encoded image up front (pure numpy work).
    stats = compute_depth_encoding_stats(depth_meters)
    encoded = encode_depth_to_rgb(depth_meters)

    # Validity classes: 0 = below DEPTH_MIN, 1 = in range, 2 = above DEPTH_MAX.
    validity_vis = np.ones_like(depth_meters)
    validity_vis[depth_meters < DEPTH_MIN] = 0  # Too close (red)
    validity_vis[depth_meters > DEPTH_MAX] = 2  # Too far (blue)

    stats_text = (
        "Statistics:\n"
        f"Min: {stats['min_depth']:.3f}m  Max: {stats['max_depth']:.3f}m  Mean: {stats['mean_depth']:.3f}m\n"
        f"Invalid (→0): {stats['invalid_pixels']:,} pixels ({stats['invalid_fraction'] * 100:.2f}%)\n"
        f"Precision: {stats['precision_mm']:.4f}mm  Range: {DEPTH_MIN}m - {DEPTH_MAX}m\n"
    )

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        # 1. Raw depth with full range (0-2m)
        ax = axes[0, 0]
        im = ax.imshow(depth_meters, cmap="turbo", vmin=0, vmax=2.0)
        ax.set_title("Raw Depth (0-2m range)")
        fig.colorbar(im, ax=ax, label="Depth (m)")

        # 2. Raw depth with encoding range
        ax = axes[0, 1]
        im = ax.imshow(depth_meters, cmap="turbo", vmin=DEPTH_MIN, vmax=DEPTH_MAX)
        ax.set_title(f"Raw Depth ({DEPTH_MIN}m-{DEPTH_MAX}m encoding range)")
        fig.colorbar(im, ax=ax, label="Depth (m)")

        # 3. Valid/invalid visualization
        ax = axes[1, 0]
        im = ax.imshow(validity_vis, cmap="RdYlGn", vmin=0, vmax=2)
        ax.set_title("Validity (Red=too close→0, Green=valid, Blue=too far→0)")
        cbar = fig.colorbar(im, ax=ax, ticks=[0, 1, 2])
        cbar.ax.set_yticklabels(["Too Close", "Valid", "Too Far"])

        # 4. Encoded RGB representation
        ax = axes[1, 1]
        ax.imshow(encoded)
        ax.set_title("Encoded as RGB (16-bit RG, 0=invalid)")

        fig.text(
            0.5,
            0.02,
            stats_text,
            ha="center",
            fontsize=10,
            family="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        fig.suptitle(title, fontsize=14, fontweight="bold")
        fig.tight_layout(rect=[0, 0.08, 1, 0.96])

        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            log.debug(f"Saved depth visualization to {save_path}")
    finally:
        # Close THIS figure (not pyplot's "current" one), even on failure.
        plt.close(fig)

    return stats

devices

Modules:

Name Description
keyboard
spacemouse

Driver class for SpaceMouse controller. Modified based on the robosuite code.

keyboard

Classes:

Name Description
Keyboard
Keyboard
Keyboard()

Methods:

Name Description
on_press
on_release

Attributes:

Name Type Description
ANGULAR_SPEED
ARM_ROT_SPEED
LINEAR_SPEED
active_arm_index
active_robot
all_robot_arms
base_modes
key_states
key_states_lock
listener
num_robots
v_arm_rot
v_lin_x
v_lin_y
v_lin_z
v_yaw
Source code in molmo_spaces/utils/devices/keyboard.py
def __init__(self) -> None:
    """Initialize teleop speeds and state, then start the keyboard listener.

    The listener invokes ``on_press``/``on_release`` on its own thread, and
    those callbacks read ``key_states``, ``base_modes``, ``active_robot``,
    ``active_arm_index``, ``all_robot_arms`` and ``num_robots``. Every one of
    these is assigned BEFORE the listener starts. Otherwise an early key
    release would hit a missing attribute and the resulting AttributeError
    would silently swallow the event.
    """
    self.LINEAR_SPEED = 0.14
    self.ANGULAR_SPEED = 3.14
    self.ARM_ROT_SPEED = 0.00628

    # Commanded velocities, all starting at rest (presumably consumed by a
    # teleop loop elsewhere; not read in this class's visible methods).
    self.v_lin_x = 0.0
    self.v_lin_y = 0.0
    self.v_lin_z = 0.0
    self.v_yaw = 0.0
    self.v_arm_rot = 0.0

    # Keys currently held down; written by on_press under key_states_lock.
    self.key_states = {}
    self.key_states_lock = threading.Lock()

    # Robot/arm selection state toggled by on_release.
    self.base_modes = {0: True}
    self.active_robot = 0
    self.active_arm_index = 0
    self.all_robot_arms = {0: [0]}
    self.num_robots = 1

    # Start the listener last so its callbacks never see a partially built object.
    self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release)
    self.listener.start()
ANGULAR_SPEED instance-attribute
ANGULAR_SPEED = 3.14
ARM_ROT_SPEED instance-attribute
ARM_ROT_SPEED = 0.00628
LINEAR_SPEED instance-attribute
LINEAR_SPEED = 0.14
active_arm_index instance-attribute
active_arm_index = 0
active_robot instance-attribute
active_robot = 0
all_robot_arms instance-attribute
all_robot_arms = {0: [0]}
base_modes instance-attribute
base_modes = {0: True}
key_states instance-attribute
key_states = {}
key_states_lock instance-attribute
key_states_lock = Lock()
listener instance-attribute
listener = Listener(on_press=on_press, on_release=on_release)
num_robots instance-attribute
num_robots = 1
v_arm_rot instance-attribute
v_arm_rot = 0.0
v_lin_x instance-attribute
v_lin_x = 0.0
v_lin_y instance-attribute
v_lin_y = 0.0
v_lin_z instance-attribute
v_lin_z = 0.0
v_yaw instance-attribute
v_yaw = 0.0
on_press
on_press(key) -> None
Source code in molmo_spaces/utils/devices/keyboard.py
def on_press(self, key) -> None:
    """Listener callback: mark ``key`` as currently held down.

    The update is done while holding ``key_states_lock``. An AttributeError
    (e.g. the lock or the state dict does not exist yet) is ignored.
    """
    try:
        guard = self.key_states_lock
        with guard:
            pressed = self.key_states
            pressed[key] = True
    except AttributeError:
        pass
on_release
on_release(key) -> None
Source code in molmo_spaces/utils/devices/keyboard.py
def on_release(self, key) -> None:
    """Listener callback: handle mode-switch keys on release.

    - ``b``: toggle the mobile-base mode of the active robot
    - ``s``: cycle to the next arm of the active robot
    - ``=``: cycle to the next robot

    Any AttributeError is swallowed. That includes keys that have no
    ``char`` attribute (presumably special keys).
    """
    try:
        symbol = key.char
        if symbol == "b":
            robot = self.active_robot
            self.base_modes[robot] = not self.base_modes[robot]
        elif symbol == "s":
            next_arm = self.active_arm_index + 1
            arm_count = len(self.all_robot_arms[self.active_robot])
            self.active_arm_index = next_arm % arm_count
        elif symbol == "=":
            self.active_robot = (self.active_robot + 1) % self.num_robots
    except AttributeError:
        pass

spacemouse

Driver class for SpaceMouse controller. Modified based on the robosuite code.

This module provides driver support for the SpaceMouse on macOS. By default, it assumes you are using a SpaceMouse Wireless.

To set up a new SpaceMouse controller
  1. Download and install driver from https://www.3dconnexion.com/service/drivers.html
  2. Install the hidapi library through pip (if the hid package is already installed, uninstall it first).
  3. Make sure SpaceMouse is connected before running the script
  4. (Optional) Based on the model of SpaceMouse, you might need to change the vendor id and product id that correspond to the device.

For Linux support, you can find open-source Linux drivers and SDKs online. See http://spacenav.sourceforge.net/

Classes:

Name Description
SpaceMouse

A minimalistic driver class for SpaceMouse with HID library.

Functions:

Name Description
convert

Converts SpaceMouse message to commands.

nms_max_axis

Suppress all but the axis with the maximum |value|.

scale_to_control

Normalize raw HID readings to target range.

to_int16

Convert two 8 bit bytes to a signed 16 bit integer.

Attributes:

Name Type Description
AxisSpec
SPACE_MOUSE_SPEC
SPACE_MOUSE_WIRELESS_SPEC
space_mouse
AxisSpec module-attribute
AxisSpec = namedtuple('AxisSpec', ['channel', 'byte1', 'byte2', 'scale'])
SPACE_MOUSE_SPEC module-attribute
SPACE_MOUSE_SPEC = {'x': AxisSpec(channel=1, byte1=3, byte2=4, scale=-1), 'y': AxisSpec(channel=1, byte1=1, byte2=2, scale=-1), 'z': AxisSpec(channel=1, byte1=5, byte2=6, scale=-1), 'roll': AxisSpec(channel=1, byte1=5, byte2=6, scale=-1), 'pitch': AxisSpec(channel=1, byte1=3, byte2=4, scale=-1), 'yaw': AxisSpec(channel=1, byte1=1, byte2=2, scale=1)}
SPACE_MOUSE_WIRELESS_SPEC module-attribute
SPACE_MOUSE_WIRELESS_SPEC = {'x': AxisSpec(channel=1, byte1=1, byte2=2, scale=1), 'y': AxisSpec(channel=1, byte1=3, byte2=4, scale=-1), 'z': AxisSpec(channel=1, byte1=5, byte2=6, scale=-1), 'roll': AxisSpec(channel=1, byte1=7, byte2=8, scale=-1), 'pitch': AxisSpec(channel=1, byte1=9, byte2=10, scale=-1), 'yaw': AxisSpec(channel=1, byte1=11, byte2=12, scale=1)}
space_mouse module-attribute
space_mouse = SpaceMouse()
SpaceMouse
SpaceMouse(vendor_id=9583, product_id=50746, pos_sensitivity=1.0, rot_sensitivity=1.0)

A minimalistic driver class for SpaceMouse with HID library.

Note: Use hid.enumerate() to view all USB human interface devices (HID). Make sure SpaceMouse is detected before running the script. You can look up its vendor/product id from this method.

Parameters:

Name Type Description Default
vendor_id int

USB vendor ID of the SpaceMouse

9583
product_id int

USB product ID of the SpaceMouse (50746 for wireless, 50741 for wired, 50770 for the USB receiver)

50746
pos_sensitivity float

Magnitude of input position command scaling

1.0
rot_sensitivity float

Magnitude of input rotation command scaling

1.0

Methods:

Name Description
get_controller_state

Grabs the current state of the 3D mouse.

rotation_matrix
run

Listener method that keeps pulling new messages.

start_control

Method that should be called externally before the controller can start receiving commands.

Attributes:

Name Type Description
control

Grabs current pose of Spacemouse

device
gripper
gripper_state
last_button_state
last_reset_button_state
pos_sensitivity
product_id
reset_button_state
rot_sensitivity
rotation
thread
vendor_id
Source code in molmo_spaces/utils/devices/spacemouse.py
def __init__(
    self,
    vendor_id=9583,
    product_id=50746,  # 50746 for wireless, 50741 for wire, 50770 for usb receiver
    pos_sensitivity=1.0,
    rot_sensitivity=1.0,
) -> None:
    """Open the SpaceMouse HID device and start the background listener.

    A failure to open the device (OSError) is printed and otherwise ignored.
    The manufacturer/product strings are printed, all 6-DoF, button and
    orientation state is zeroed, and ``run`` is launched on a daemon thread.
    """
    print("Opening SpaceMouse device")
    self.vendor_id, self.product_id = vendor_id, product_id
    self.device = hid.device()
    try:
        self.device.open(self.vendor_id, self.product_id)  # SpaceMouse
    except OSError as err:
        # Best effort: report the problem and carry on.
        print("Failed to open SpaceMouse device cause: ", err)

    self.pos_sensitivity, self.rot_sensitivity = pos_sensitivity, rot_sensitivity

    print(f"Manufacturer: {self.device.get_manufacturer_string()}")
    print(f"Product: {self.device.get_product_string()}")

    # 6-DOF readings, updated by the listener thread.
    self.x = self.y = self.z = 0
    self.roll = self.pitch = self.yaw = 0

    # Gripper toggle state: gripper is 255.0 when closed/toggled on, else 0.0.
    self.gripper = 0.0
    self.gripper_state = False
    self.last_button_state = 0

    self.reset_button_state = False
    self.last_reset_button_state = False

    self._control = [0.0] * 6
    self._reset_state = 0
    # Initial accumulated orientation: diag(-1, 1, -1).
    self.rotation = np.diag([-1.0, 1.0, -1.0])
    # Reports are ignored by run() until start_control() flips this.
    self._enabled = False

    # launch a new listener thread to listen to SpaceMouse
    self.thread = threading.Thread(target=self.run, daemon=True)
    self.thread.start()
control property
control

Grabs current pose of Spacemouse

Returns:

Type Description

np.array: 6-DoF control value

device instance-attribute
device = device()
gripper instance-attribute
gripper = 0.0
gripper_state instance-attribute
gripper_state = False
last_button_state instance-attribute
last_button_state = 0
last_reset_button_state instance-attribute
last_reset_button_state = False
pos_sensitivity instance-attribute
pos_sensitivity = pos_sensitivity
product_id instance-attribute
product_id = product_id
reset_button_state instance-attribute
reset_button_state = False
rot_sensitivity instance-attribute
rot_sensitivity = rot_sensitivity
rotation instance-attribute
rotation = array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])
thread instance-attribute
thread = Thread(target=run)
vendor_id instance-attribute
vendor_id = vendor_id
get_controller_state
get_controller_state()

Grabs the current state of the 3D mouse.

Returns:

Name Type Description
dict

A dictionary containing dpos, raw_drotation, grasp, and reset

Source code in molmo_spaces/utils/devices/spacemouse.py
def get_controller_state(self):
    """
    Grabs the current state of the 3D mouse.

    Side effect: the incremental roll/pitch/yaw rotation is composed onto
    ``self.rotation``, which accumulates an absolute orientation.

    Returns:
        dict: with keys
            - ``dpos``: translation delta, ``control[:3] * 0.005 * pos_sensitivity``
            - ``raw_drotation``: ``np.array([roll, pitch, yaw])`` deltas,
              ``control[3:] * 0.005 * rot_sensitivity``
            - ``grasp``: current gripper command (``self.gripper``)
            - ``reset``: ``self._reset_state``
    """
    # Read the control vector ONCE. The listener thread keeps replacing
    # ``_control`` (presumably what the ``control`` property wraps), so two
    # reads could mix translation and rotation from different HID samples.
    control = self.control
    dpos = control[:3] * 0.005 * self.pos_sensitivity
    roll, pitch, yaw = control[3:] * 0.005 * self.rot_sensitivity

    # convert RPY to an absolute orientation
    drot1 = self.rotation_matrix(angle=-pitch, direction=[1.0, 0, 0], point=None)[:3, :3]
    drot2 = self.rotation_matrix(angle=roll, direction=[0, 1.0, 0], point=None)[:3, :3]
    drot3 = self.rotation_matrix(angle=yaw, direction=[0, 0, 1.0], point=None)[:3, :3]

    self.rotation = self.rotation.dot(drot1.dot(drot2.dot(drot3)))

    return dict(
        dpos=dpos,
        raw_drotation=np.array([roll, pitch, yaw]),
        grasp=self.gripper,
        reset=self._reset_state,
    )
rotation_matrix
rotation_matrix(angle, direction, point=None)
Source code in molmo_spaces/utils/devices/spacemouse.py
def rotation_matrix(self, angle, direction, point=None):
    """Return the 3x3 matrix rotating by ``angle`` radians about ``direction``.

    ``direction`` is normalized before use. ``point`` is accepted for
    signature compatibility but ignored, so the result carries no translation.
    """
    axis = np.array(direction)
    unit_axis = axis / np.linalg.norm(direction)
    rotvec = angle * unit_axis
    return R.from_rotvec(rotvec).as_matrix()
run
run() -> None

Listener method that keeps pulling new messages.

Source code in molmo_spaces/utils/devices/spacemouse.py
def run(self) -> None:
    """Listener method that keeps pulling new messages.

    Loops forever; ``__init__`` starts it on a daemon thread. Reports are
    ignored until ``start_control()`` sets ``self._enabled``. Report layout,
    keyed on the first byte:

    - ``product_id == 50741`` (older model): id 1 updates y/x/z and id 2 updates
      roll/pitch/yaw. ``self._control`` is refreshed when id 2 arrives.
    - every other model: id 1 carries all six axes (bytes 1-12) and refreshes
      ``self._control``.
    - id 3 (any model): side-button state in byte 1. The value 1 toggles the
      gripper on the press edge, and the value 2 sets ``reset_button_state``.
    """
    while True:
        d = self.device.read(13)

        # Skip empty reads (None or []). Indexing an empty report would raise
        # IndexError and silently kill this listener thread.
        if not d or not self._enabled:
            continue

        report_id = d[0]
        if self.product_id == 50741:
            # logic for older spacemouse model
            if report_id == 1:  # readings from 6-DoF sensor (translation)
                self.y = convert(d[1], d[2])
                self.x = convert(d[3], d[4])
                self.z = convert(d[5], d[6]) * -1.0

            elif report_id == 2:  # rotation readings
                self.roll = convert(d[1], d[2])
                self.pitch = convert(d[3], d[4])
                self.yaw = convert(d[5], d[6])

                self._control = [self.x, self.y, self.z, self.roll, self.pitch, self.yaw]

        elif report_id == 1:
            # default logic for all other spacemouse models: all 6 axes in one report
            self.y = convert(d[1], d[2]) * -1.0
            self.x = convert(d[3], d[4]) * -1.0
            self.z = convert(d[5], d[6]) * -1.0

            self.roll = convert(d[7], d[8]) * -1.0
            self.pitch = convert(d[9], d[10]) * -1.0
            self.yaw = convert(d[11], d[12]) * -1.0

            self._control = [self.x, self.y, self.z, self.roll, self.pitch, self.yaw]

        if report_id == 3:  # readings from the side buttons
            # Left button (value 1): toggle the gripper only on the press edge.
            if d[1] == 1 and not self.last_button_state:
                self.gripper_state = not self.gripper_state
                self.gripper = 255.0 if self.gripper_state else 0.0
            self.last_button_state = d[1]

            # Value 2 reports reset for as long as it is seen.
            self.reset_button_state = d[1] == 2
            self.last_reset_button_state = d[1]
start_control
start_control() -> None

Method that should be called externally before controller can start receiving commands.

Source code in molmo_spaces/utils/devices/spacemouse.py
def start_control(self) -> None:
    """
    Arm the controller so that it begins accepting commands.

    External code must invoke this before any commands are read. Internal
    state is cleared first; afterwards the reset flag is zeroed and the
    device is marked as enabled.
    """
    self._reset_internal_state()
    # Assigned after the internal reset (same ordering as before), so these
    # values win even if the reset routine touches them.
    self._reset_state, self._enabled = 0, True
convert
convert(b1, b2)

Converts SpaceMouse message to commands.

Parameters:

Name Type Description Default
b1 int

8-bit byte

required
b2 int

8-bit byte

required

Returns:

Name Type Description
float

Scaled value from Spacemouse message

Source code in molmo_spaces/utils/devices/spacemouse.py
def convert(b1, b2):
    """
    Turn one SpaceMouse axis reading (a pair of raw bytes) into a control value.

    The bytes are first combined into a signed 16-bit integer with
    ``to_int16`` (``b1`` is the low byte, ``b2`` the high byte), and the
    result is normalized and clipped by ``scale_to_control``.

    Args:
        b1 (int): 8-bit byte
        b2 (int): 8-bit byte

    Returns:
        float: Scaled value from Spacemouse message
    """
    raw_reading = to_int16(b1, b2)
    return scale_to_control(raw_reading)
nms_max_axis
nms_max_axis(control: ndarray, threshold=0.6)

Suppress all but the axis with the maximum |value|. The max axis is set to -1 or 1 based on sign, others are zeroed.

Parameters:

Name Type Description Default
control ndarray

6D input vector, assumed scaled in [-1, 1]

required
threshold float

minimum |value| to count as valid input

0.6

Returns:

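Type Description
np.ndarray

Filtered control vector: the axis with the largest |value| is set to -1 or 1 (its sign) and all other entries are zero; all zeros if no entry reaches the threshold.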

Source code in molmo_spaces/utils/devices/spacemouse.py
def nms_max_axis(control: np.ndarray, threshold=0.6):
    """
    Suppress all but the axis with the maximum |value|.
    The max axis is set to -1 or 1 based on sign, others are zeroed.

    Args:
        control (np.ndarray): 6D input vector, assumed scaled in [-1, 1]
        threshold (float): minimum |value| to count as valid input

    Returns:
        np.ndarray: filtered control vector with only max direction. It has
        the same shape and dtype as ``control``. If every |value| is below
        ``threshold`` (or ``control`` is empty), the result is all zeros.
    """
    # Compute magnitudes once and reuse them for both the threshold test and argmax.
    magnitudes = np.abs(control)
    out = np.zeros_like(control)
    if np.all(magnitudes < threshold):
        return out

    # Ties resolve to the lowest index (np.argmax semantics).
    max_idx = np.argmax(magnitudes)
    out[max_idx] = np.sign(control[max_idx])
    # `out` is already a fresh array; no extra copy is needed.
    return out
scale_to_control
scale_to_control(x, axis_scale=350.0, min_v=-1.0, max_v=1.0)

Normalize raw HID readings to target range.

Parameters:

Name Type Description Default
x int

Raw reading from HID

required
axis_scale float

(Inverted) scaling factor for mapping raw input value

350.0
min_v float

Minimum limit after scaling

-1.0
max_v float

Maximum limit after scaling

1.0

Returns:

Name Type Description
float

Clipped, scaled input from HID

Source code in molmo_spaces/utils/devices/spacemouse.py
def scale_to_control(x, axis_scale=350.0, min_v=-1.0, max_v=1.0):
    """
    Map a raw HID axis reading into the ``[min_v, max_v]`` control range.

    The reading is divided by ``axis_scale``, and the quotient is then
    clamped from below by ``min_v`` and from above by ``max_v``.

    Args:
        x (int): Raw reading from HID
        axis_scale (float): (Inverted) scaling factor for mapping raw input value
        min_v (float): Minimum limit after scaling
        max_v (float): Maximum limit after scaling

    Returns:
        float: Clipped, scaled input from HID
    """
    scaled = x / axis_scale
    lower_bounded = max(scaled, min_v)
    return min(lower_bounded, max_v)
to_int16
to_int16(y1, y2)

Convert two 8-bit bytes (y1 is the low byte, y2 the high byte) to a signed 16-bit integer.

Parameters:

Name Type Description Default
y1 int

8-bit byte

required
y2 int

8-bit byte

required

Returns:

Name Type Description
int

16-bit integer

Source code in molmo_spaces/utils/devices/spacemouse.py
def to_int16(y1, y2):
    """
    Combine two bytes into a signed (two's complement) 16-bit integer.

    ``y1`` supplies the low byte and ``y2`` the high byte.

    Args:
        y1 (int): 8-bit byte
        y2 (int): 8-bit byte

    Returns:
        int: 16-bit integer
    """
    unsigned = (y2 << 8) | y1
    # Values with the sign bit set wrap into the negative range.
    return unsigned - 65536 if unsigned >= 32768 else unsigned

distance_transform_utils

Functions:

Name Description
cost_function
get_pixel_cost
get_segment_cost
make_discrete_path
make_distance_transform
make_grid_graph
simplify_path_greedy

cost_function

cost_function(x, weight_exp, zeroish=1e-30)
Source code in molmo_spaces/utils/distance_transform_utils.py
def cost_function(x, weight_exp, zeroish=1e-30):
    """Return the inverse-power cost ``max(x, zeroish) ** -weight_exp`` elementwise.

    Inputs are floored at ``zeroish`` so that zeros yield a very large but
    finite cost instead of a division by zero.
    """
    floored = np.maximum(x, zeroish)
    return np.power(floored, -weight_exp)

get_pixel_cost

get_pixel_cost(cost, p)
Source code in molmo_spaces/utils/distance_transform_utils.py
def get_pixel_cost(cost, p):
    """Look up ``cost`` at the cell nearest the (row, col) point ``p``.

    Coordinates are rounded with Python's ``round`` (ties go to the even integer).
    """
    row = round(p[0])
    col = round(p[1])
    return cost[row, col]

get_segment_cost

get_segment_cost(cost, p1, p2)
Source code in molmo_spaces/utils/distance_transform_utils.py
def get_segment_cost(cost, p1, p2):
    """Total and peak of ``cost`` over the rasterized segment from ``p1`` to ``p2``.

    Both endpoints are rounded to the nearest cell. The cells between them are
    enumerated by ``line``, presumably ``skimage.draw.line``, which includes
    both endpoints (confirm against the module imports).

    Returns:
        tuple: ``(total, peak)`` of the cost values along the segment.
    """
    r0, c0 = round(p1[0]), round(p1[1])
    r1, c1 = round(p2[0]), round(p2[1])
    along = cost[line(r0, c0, r1, c1)]
    return along.sum(), along.max()

make_discrete_path

make_discrete_path(graph, source_row, source_col, target_row, target_col, distance_transform, weight_exp, grid_spacing, max_distance_to_obstacle)
Source code in molmo_spaces/utils/distance_transform_utils.py
def make_discrete_path(
    graph,
    source_row,
    source_col,
    target_row,
    target_col,
    distance_transform,
    weight_exp,
    grid_spacing,
    max_distance_to_obstacle,
):
    """Plan a cell path with A* and reduce it to a sparse list of waypoints.

    Args:
        graph: networkx graph whose nodes are (row, col) tuples, e.g. the
            output of ``make_grid_graph``.
        source_row, source_col: Start cell.
        target_row, target_col: Goal cell.
        distance_transform, weight_exp, grid_spacing, max_distance_to_obstacle:
            Forwarded unchanged to ``simplify_path_greedy``.

    Returns:
        tuple: ``(waypoints, locs, path_cost)``. These are the simplified
        waypoints, the full cell-by-cell A* path, and the cost reported by
        the simplification.
    """
    start = (source_row, source_col)
    goal = (target_row, target_col)
    cell_path = nx.astar_path(graph, start, goal)
    waypoints, path_cost = simplify_path_greedy(
        cell_path,
        distance_transform,
        weight_exp,
        grid_spacing,
        max_distance_to_obstacle,
    )
    return waypoints, cell_path, path_cost

make_distance_transform

make_distance_transform(grid, grid_spacing=None, max_distance_to_obstacle=0.75)
Source code in molmo_spaces/utils/distance_transform_utils.py
def make_distance_transform(grid, grid_spacing=None, max_distance_to_obstacle=0.75):
    """Euclidean distance (in cells) from every non-zero cell of ``grid`` to the nearest zero cell.

    The grid is padded with a one-cell border of zeros before the transform,
    so the map boundary counts as a zero (obstacle) cell as well.

    Args:
        grid: 2D array; zero cells are obstacles, non-zero cells are free.
        grid_spacing: Size of one cell in world units. When given, distances
            are clipped to ``max_distance_to_obstacle / grid_spacing`` cells.
        max_distance_to_obstacle: Clipping distance in world units; only used
            when ``grid_spacing`` is given.

    Returns:
        Array with the same shape as ``grid`` holding distances in cells.

    Raises:
        ValueError: If ``grid_spacing`` is given but not strictly positive.
    """
    padded = np.zeros((grid.shape[0] + 2, grid.shape[1] + 2), dtype=grid.dtype)
    padded[1:-1, 1:-1] = grid
    dt = distance_transform_edt(padded)[1:-1, 1:-1]
    if grid_spacing is None:
        return dt
    # Zero spacing would divide by zero below (or give an inf clip bound for
    # numpy scalars), and a negative one a negative bound. Reject both with a
    # real exception rather than an `assert`, which `python -O` strips.
    if grid_spacing <= 0:
        raise ValueError(f"grid_spacing should be positive, if given (got {grid_spacing})")
    return np.clip(dt, 0.0, max_distance_to_obstacle / grid_spacing)

make_grid_graph

make_grid_graph(grid, dt, weight_exp=2)
Source code in molmo_spaces/utils/distance_transform_utils.py
def make_grid_graph(
    grid,
    dt,
    weight_exp=2,
):
    """Build a 4-connected weighted graph over the non-zero cells of ``grid``.

    Nodes are ``(row, col)`` tuples. Two orthogonally adjacent cells are joined
    when both are non-zero in ``grid``. The edge weight is the larger of the
    two cells' costs, where cost is ``cost_function(dt, weight_exp)``. Cells
    with no non-zero neighbour never appear in the graph, because nodes are
    only created through edges.
    """
    cost = cost_function(dt, weight_exp)
    rows, cols = grid.shape[0], grid.shape[1]
    # labels[r, c] == (r, c): per-cell coordinates, sliced in step with grid/cost.
    labels = np.stack(np.meshgrid(range(rows), range(cols), indexing="ij"), axis=2)

    def neighbour_edges(first, second):
        # `first`/`second` are slice pairs selecting each cell and its neighbour
        # in one direction. Yields (node_a, node_b, weight) triples lazily.
        both_free = np.nonzero(grid[first] * grid[second])
        weights = np.maximum(cost[first], cost[second])[both_free]
        nodes_a = labels[first][both_free]
        nodes_b = labels[second][both_free]
        return ((tuple(a), tuple(b), w) for a, b, w in zip(nodes_a, nodes_b, weights))

    everything = slice(None)
    # Vertical neighbours first, then horizontal: edge insertion order is kept stable.
    vertical = neighbour_edges((slice(None, -1), everything), (slice(1, None), everything))
    horizontal = neighbour_edges((everything, slice(None, -1)), (everything, slice(1, None)))

    graph = nx.Graph()
    graph.add_weighted_edges_from(chain(vertical, horizontal))
    return graph

simplify_path_greedy

simplify_path_greedy(waypoints, dt, weight_exp, grid_spacing, max_distance_to_obstacle)
Source code in molmo_spaces/utils/distance_transform_utils.py
def simplify_path_greedy(waypoints, dt, weight_exp, grid_spacing, max_distance_to_obstacle):
    """Greedily shortcut a cell path while the straight-line cost does not get worse.

    Walks ``waypoints`` in order. At each step it tries to replace the current
    segment plus the next step with a single straight segment from the last
    kept waypoint. A shortcut is accepted when its summed cost is no larger
    than the summed cost of the two pieces it replaces. Its peak cost must
    also be no higher than theirs, or below the cost of a cell lying
    ``max_distance_to_obstacle`` (world units) from an obstacle.

    Args:
        waypoints: Sequence of (row, col) cells, e.g. an A* path.
        dt: Distance transform of the grid (distances in cells).
        weight_exp: Exponent passed to ``cost_function``. With 0, every cell
            costs 1 except ``dt == 0`` cells, which are made prohibitive.
        grid_spacing: Cell size in world units.
        max_distance_to_obstacle: Clearance in world units beyond which a
            shortcut's peak cost is always acceptable.

    Returns:
        tuple: ``(kept_waypoints, path_cost)`` where ``path_cost`` is a scalar.
        The value depends on the number of waypoints:

        * none: ``inf``.
        * one: that cell's cost.
        * two or more: the sum of each kept segment's rasterized cost minus
          the cost of that segment's start cell.
    """
    if len(waypoints) == 0:
        return [], np.inf

    cost = cost_function(dt, weight_exp)
    if weight_exp == 0:
        # Make obstacles effectively non-visitable
        cost[dt == 0] = 1e30

    if len(waypoints) == 1:
        return [waypoints[0]], get_pixel_cost(cost, waypoints[0])

    # Two waypoints go through the general loop below, which then runs zero
    # iterations. That keeps path_cost a scalar, computed with the same
    # convention as longer paths, rather than get_segment_cost's (sum, peak) tuple.

    # If shortcutting a section of the path doesn't increase the cost, that's what we should do!
    #  obviously, 4-connectivity won't bring us there (8-connectivity should? - but expensive)
    acceptable_cost = cost_function(max_distance_to_obstacle / grid_spacing, weight_exp)

    res = [waypoints[0]]
    path_cost = 0.0

    def commit(segment_cost, segment_end):
        # Keep `segment_end` as a waypoint. Return the segment's cost minus its
        # start cell, so a cell shared by consecutive segments is counted once.
        increment = segment_cost - get_pixel_cost(cost, res[-1])
        res.append(segment_end)
        return increment

    cur_next = waypoints[1]
    cur_cost, cur_peak = get_segment_cost(cost, res[-1], cur_next)
    for idx in range(2, len(waypoints)):
        pos_next = waypoints[idx]

        # We add extra cost to each last step in each segment, which is incorrect, but works okay
        pos_cost, pos_peak = get_segment_cost(cost, res[-1], pos_next)
        pos_local_cost, local_peak = get_segment_cost(cost, cur_next, pos_next)

        if pos_cost <= cur_cost + pos_local_cost and (
            pos_peak <= max(cur_peak, local_peak) or pos_peak < acceptable_cost
        ):
            # Shortcut accepted: stretch the current straight segment to pos_next.
            cur_cost = pos_cost
            cur_peak = pos_peak
        else:
            # Shortcut rejected: keep cur_next and start a new segment from it.
            path_cost += commit(cur_cost, cur_next)
            cur_cost = pos_local_cost
            cur_peak = local_peak

        cur_next = pos_next

    path_cost += commit(cur_cost, cur_next)

    return res, path_cost

eval_camera_randomization_utils

Level → value scaling for camera and light randomization.

Level is in [0, 100]: 0 = no randomization, 100 = maximum. Each parameter has an output range (min, max) and a mapping function that maps level to a value in that range.

Functions:

Name Description
add_eval_camera_args

Add eval camera CLI flags to an argparse parser.

apply_camera_perturbation

Sample a camera pose by perturbing the reference pose in spherical coordinates.

apply_camera_randomization_level

Return a copy of the camera system config with randomization params set via interpolation.

build_eval_camera_config_from_args

Build a FrankaEvalCameraSystem from parsed CLI args, or return None if not requested.

debug
derive_episode_camera_seed

Derive a deterministic seed for camera randomization from episode identity.

piecewise_linear

Piecewise linear interpolation.

resolve_reference_pose

Compute world-frame pos/forward/up from the reference body and return an updated copy.

setup_eval_cameras

Set up eval cameras: wrist via MJCF, exo via spherical perturbation.

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

add_eval_camera_args

add_eval_camera_args(parser: ArgumentParser) -> None

Add eval camera CLI flags to an argparse parser.

These flags are shared across all JSON eval entry points (standalone and distributed).

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def add_eval_camera_args(parser: argparse.ArgumentParser) -> None:
    """Register the eval-camera randomization options on ``parser``.

    The options are placed in their own argument group and are shared by every
    JSON eval entry point (standalone and distributed):

    * ``--use_eval_cameras``: boolean switch, default ``False``.
    * ``--camera_rand_level``: float level in 0-100, default ``0.0``.
    """
    options = (
        (
            "--use_eval_cameras",
            dict(
                action="store_true",
                default=False,
                help="Use FrankaEvalCameraSystem with randomization instead of recorded cameras from JSON.",
            ),
        ),
        (
            "--camera_rand_level",
            dict(
                type=float,
                default=0.0,
                help="Camera randomization level (0-100). Only used with --use_eval_cameras.",
            ),
        ),
    )
    group = parser.add_argument_group("Eval camera randomization")
    for flag, kwargs in options:
        group.add_argument(flag, **kwargs)

apply_camera_perturbation

apply_camera_perturbation(cam: EvalExocentricCameraConfig, ref_forward: ndarray, ref_up: ndarray, workspace_center: ndarray, rng: RandomState) -> tuple[ndarray, ndarray, ndarray, float]

Sample a camera pose by perturbing the reference pose in spherical coordinates.

Orientation is computed via slerp between the calibrated rotation (from the shoulder-mount quaternion) and a look-at rotation toward the workspace center, controlled by workspace_center_weight. At weight 0 the camera keeps its original orientation; at weight 1 it looks straight at the workspace center (plus optional noise).

Parameters:

Name Type Description Default
cam EvalExocentricCameraConfig

Resolved EvalExocentricCameraConfig (pos/forward/up already set).

required
ref_forward ndarray

Forward vector from the resolved quaternion-based reference pose.

required
ref_up ndarray

Up vector from the resolved quaternion-based reference pose.

required
workspace_center ndarray

3D point the camera should look at.

required
rng RandomState

Seeded random state for deterministic sampling.

required

Returns:

Type Description
tuple[ndarray, ndarray, ndarray, float]

(pos, forward, up, fov) in world frame.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def apply_camera_perturbation(
    cam: EvalExocentricCameraConfig,
    ref_forward: np.ndarray,
    ref_up: np.ndarray,
    workspace_center: np.ndarray,
    rng: np.random.RandomState,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """Sample a camera pose by perturbing the reference pose in spherical coordinates.

    The reference position ``cam.pos`` is decomposed into (azimuth, distance,
    height) relative to ``workspace_center``. Each component is shifted by a
    value drawn uniformly from the matching ``cam.*_range`` (no shift when the
    range is ``None``), and the position is rebuilt from the shifted values.
    The distance is floored at 0.10.

    Orientation is computed via slerp between the calibrated rotation (from the
    shoulder-mount quaternion) and a look-at rotation toward the workspace
    center, controlled by ``workspace_center_weight``.  At weight 0 the camera
    keeps its original orientation; at weight 1 it looks straight at the
    workspace center (plus optional noise from ``lookat_noise_range``).

    Unlike the positional ranges, ``fov_range`` is absolute. When set, the
    returned FOV is (up to float rounding) the value drawn from it, not an
    offset added to ``cam.fov``.

    Random draws come from ``rng`` in a fixed order: azimuth, distance,
    height, look-at noise (3 values), FOV. Draws whose range is ``None`` are
    skipped. Reordering these calls changes the poses produced for a given seed.

    Args:
        cam: Resolved EvalExocentricCameraConfig (pos/forward/up already set).
        ref_forward: Forward vector from the resolved quaternion-based reference pose.
        ref_up: Up vector from the resolved quaternion-based reference pose.
        workspace_center: 3D point the camera should look at.
        rng: Seeded random state for deterministic sampling.

    Returns:
        (pos, forward, up, fov) in world frame. pos/forward/up are float32
        arrays and fov is a Python float.
    """
    ref_pos = np.array(cam.pos, dtype=np.float64)
    ref_azimuth, ref_distance, ref_height = _decompose_to_spherical(ref_pos, workspace_center)

    azimuth_shift = 0
    if cam.azimuth_range is not None:
        azimuth_shift = rng.uniform(cam.azimuth_range[0], cam.azimuth_range[1])
    azimuth = ref_azimuth + azimuth_shift

    distance_shift = 0
    if cam.distance_range is not None:
        distance_shift = rng.uniform(cam.distance_range[0], cam.distance_range[1])
    # Floor keeps the camera from collapsing onto the workspace center. The
    # logged distance_shift below is the raw draw, before this floor.
    distance = max(ref_distance + distance_shift, 0.10)

    height_shift = 0
    if cam.height_range is not None:
        height_shift = rng.uniform(cam.height_range[0], cam.height_range[1])
    height = ref_height + height_shift

    # Reconstruct Cartesian position
    # (azimuth is in radians since it feeds cos/sin; height is an offset along world z)
    pos = workspace_center.copy().astype(np.float64)
    pos[0] += distance * np.cos(azimuth)
    pos[1] += distance * np.sin(azimuth)
    pos[2] += height

    # Orientation: slerp between calibrated (ref) and look-at-workspace-center
    lookat_weight = 0.0
    if cam.workspace_center_weight is not None:
        lookat_weight = cam.workspace_center_weight

    # Build calibrated rotation from ref_forward / ref_up
    # NOTE(review): if ref_forward is parallel to ref_up, ref_right is ~0 and
    # the calibrated rotation is degenerate.
    ref_fwd = np.array(ref_forward, dtype=np.float64)
    ref_fwd /= np.linalg.norm(ref_fwd) + 1e-12
    ref_u = np.array(ref_up, dtype=np.float64)
    ref_u /= np.linalg.norm(ref_u) + 1e-12
    ref_right = np.cross(ref_fwd, ref_u)
    ref_right /= np.linalg.norm(ref_right) + 1e-12
    # Re-orthogonalise up
    ref_u = np.cross(ref_right, ref_fwd)
    # Matrix columns are the camera's right, up and backward axes; the camera
    # looks along -z (forward is recovered as -mat[:, 2] below).
    rot_calibrated = R.from_matrix(np.column_stack([ref_right, ref_u, -ref_fwd]))

    # Build look-at rotation toward workspace center (+ optional noise)
    lookat_target = workspace_center.copy().astype(np.float64)
    if cam.lookat_noise_range is not None:
        # Independent uniform noise on each of the three target coordinates.
        lookat_target += rng.uniform(cam.lookat_noise_range[0], cam.lookat_noise_range[1], size=3)

    la_fwd, la_up = compute_lookat_forward_up(pos, lookat_target)
    # NOTE(review): la_right is not re-normalised; this assumes
    # compute_lookat_forward_up returns orthonormal vectors — confirm.
    la_right = np.cross(la_fwd, la_up)
    rot_lookat = R.from_matrix(np.column_stack([la_right, la_up, -la_fwd]))

    # Slerp between calibrated and look-at orientations
    # (key times are 0 and 1, so lookat_weight presumably must lie in [0, 1])
    slerp = Slerp([0.0, 1.0], R.concatenate([rot_calibrated, rot_lookat]))
    rot_interp = slerp(lookat_weight)
    mat = rot_interp.as_matrix()
    forward = (-mat[:, 2]).astype(np.float32)
    up = mat[:, 1].astype(np.float32)

    # FOV
    # fov_range is absolute: fov ends up equal to the sampled value. fov_shift
    # (sampled - cam.fov) is kept only for the log line below.
    fov_shift = 0
    fov = cam.fov
    if cam.fov_range is not None:
        fov_shift = rng.uniform(cam.fov_range[0], cam.fov_range[1]) - fov
    fov += fov_shift

    log.info(
        "[EVAL CAMERA] Perturbation info:\n"
        f"  azimuth: {azimuth_shift} {cam.azimuth_range}\n"
        f"  distance: {distance_shift} {cam.distance_range}\n"
        f"  height: {height_shift} {cam.height_range}\n"
        f"  lookat_weight: {lookat_weight} noise_range={cam.lookat_noise_range}\n"
        f"  fov: {fov_shift} {cam.fov_range}"
    )

    return pos.astype(np.float32), forward, up, float(fov)

apply_camera_randomization_level

apply_camera_randomization_level(camera_config: FrankaEvalCameraSystem, level: float) -> FrankaEvalCameraSystem

Return a copy of the camera system config with randomization params set via interpolation.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def apply_camera_randomization_level(
    camera_config: FrankaEvalCameraSystem, level: float
) -> FrankaEvalCameraSystem:
    """Return a copy of the camera system config with randomization params set via interpolation.

    The table ``camera_config.ref_level_ranges[0][1]`` lists the randomizable
    parameters for each camera. For each such camera, every listed parameter
    is replaced by its value at ``level``. That value comes from
    ``_get_camera_param_at_level_from_config`` and is passed through
    ``_reshape_to_original`` with the table's entry. Cameras with no entry, or
    an empty one, are kept unchanged.

    Args:
        camera_config: Camera system whose cameras should be randomized.
        level: Randomization level, 0 (none) to 100 (maximum).

    Returns:
        A copy of ``camera_config`` (via ``model_copy``) with the new camera list.
    """
    new_cameras: list[Any] = []
    for cam in camera_config.cameras:
        # A camera missing from the table is treated like one with an empty
        # entry (not randomized) instead of raising KeyError. debug() reads the
        # same table the same way, with .get(cam.name, {}).
        randomizable = camera_config.ref_level_ranges[0][1].get(cam.name, {})
        if not randomizable:
            new_cameras.append(cam)
            continue

        cam_dict = cam.model_dump()
        for param in randomizable:
            cam_dict[param] = _reshape_to_original(
                _get_camera_param_at_level_from_config(camera_config, cam.name, param, level),
                randomizable[param],
            )
        new_cameras.append(cam.__class__.model_validate(cam_dict))
    return camera_config.model_copy(update={"cameras": new_cameras})

build_eval_camera_config_from_args

build_eval_camera_config_from_args(args: Namespace) -> FrankaEvalCameraSystem | None

Build a FrankaEvalCameraSystem from parsed CLI args, or return None if not requested.

Returns None if --use_eval_cameras was not passed. Otherwise, creates the default eval camera system with the requested randomization level applied; no camera subsetting is performed.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def build_eval_camera_config_from_args(
    args: argparse.Namespace,
) -> FrankaEvalCameraSystem | None:
    """Build a FrankaEvalCameraSystem from parsed CLI args, or return None if not requested.

    Returns None unless ``args.use_eval_cameras`` is truthy. A namespace that
    lacks the attribute (e.g. from a parser that never called
    ``add_eval_camera_args``) counts as not requested. Otherwise the default
    ``FrankaEvalCameraSystem`` is created, and ``args.camera_rand_level``
    (default 0.0, matching the CLI default) is applied via
    ``apply_camera_randomization_level``. No camera subsetting happens here.
    """
    if not getattr(args, "use_eval_cameras", False):
        return None

    # Imported lazily; only needed when eval cameras were requested.
    from molmo_spaces.configs.camera_configs import FrankaEvalCameraSystem

    level = getattr(args, "camera_rand_level", 0.0)
    print(f"Using camera randomization level {level}")
    return apply_camera_randomization_level(FrankaEvalCameraSystem(), level)

debug

debug()
Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def debug():
    """Ad-hoc manual check for this module (not used by library code).

    Builds an eval camera config from a synthetic argparse namespace, checks
    the resulting number of cameras, and prints the config as JSON. The nested
    ``print_randomizable_values_at_levels`` helper and the commented-out calls
    are kept for interactive debugging.
    """
    import json

    from molmo_spaces.configs.camera_configs import FrankaEvalCameraSystem

    def print_randomizable_values_at_levels(levels=(0, 10, 25, 50, 75, 90, 100)) -> None:
        # Print every randomizable parameter of every camera at several levels.
        config = FrankaEvalCameraSystem()
        for cam in config.cameras:
            params = config.ref_level_ranges[0][1].get(cam.name, {})
            if not params:
                continue
            for param in sorted(params.keys()):
                print(f"{cam.name} {param}:")
                for level in levels:
                    val = _get_camera_param_at_level_from_config(config, cam.name, param, level)
                    print(f"level {level:3.0f}: {val}")

    # print_randomizable_values_at_levels()

    # new_config = apply_camera_randomization_level(FrankaEvalCameraSystem(), level=33)
    # print(json.dumps(new_config.model_dump(), indent=2))

    def test_camera_from_args():
        from argparse import ArgumentParser

        parser = ArgumentParser()
        add_eval_camera_args(parser)
        # Parse an empty argv so this self-contained check neither depends on
        # nor fails because of the real command line of the calling process.
        args = parser.parse_args([])
        args.use_eval_cameras = True
        args.camera_rand_level = 33
        # NOTE(review): build_eval_camera_config_from_args, as visible in this
        # module, never reads no_gopro / num_zeds / disable_shoulder. The count
        # check below therefore only passes if the default camera system
        # already has `num_cameras` cameras. Confirm the intended subsetting logic.
        args.no_gopro = True
        args.num_zeds = 1
        args.disable_shoulder = True

        num_cameras = 5
        if args.no_gopro:
            num_cameras -= 1
        if args.num_zeds < 2:
            num_cameras -= 2 - args.num_zeds
        if args.disable_shoulder:
            num_cameras -= 1

        new_config = build_eval_camera_config_from_args(args)
        assert new_config is not None
        assert len(new_config.cameras) == num_cameras

        print(json.dumps(new_config.model_dump(), indent=2))

    test_camera_from_args()

derive_episode_camera_seed

derive_episode_camera_seed(episode_spec: Any) -> int

Derive a deterministic seed for camera randomization from episode identity.

The seed is a hash of fields that uniquely identify an episode so that the same (episode, level) pair always produces the same camera placement.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def derive_episode_camera_seed(episode_spec: Any) -> int:
    """Map an episode's identifying fields to a stable 32-bit camera seed.

    The identity string joins these fields with ``|``, in order:

    * ``scene_dataset``, ``data_split`` and ``house_index`` (defaults ``""``,
      ``""`` and ``0``).
    * ``source.h5_file`` and ``source.traj_key``, when ``source`` is present.
    * ``seed``, when it is not None.

    The seed is the first 8 hex digits of the SHA-256 digest of that string,
    so the same (episode, level) pair always produces the same camera placement.
    """
    import hashlib

    fields = [
        getattr(episode_spec, "scene_dataset", ""),
        getattr(episode_spec, "data_split", ""),
        getattr(episode_spec, "house_index", 0),
    ]
    source = getattr(episode_spec, "source", None)
    if source is not None:
        fields += [getattr(source, "h5_file", ""), getattr(source, "traj_key", "")]
    seed_val = getattr(episode_spec, "seed", None)
    if seed_val is not None:
        fields.append(seed_val)

    identity = "|".join(map(str, fields))
    digest = hashlib.sha256(identity.encode()).hexdigest()
    return int(digest[:8], 16)

piecewise_linear

piecewise_linear(level: float, breakpoints: list[float], values: list[float]) -> float

Piecewise linear interpolation.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def piecewise_linear(level: float, breakpoints: list[float], values: list[float]) -> float:
    """Interpolate ``values`` at ``level`` along the polyline through ``breakpoints``.

    Levels at or below the first breakpoint return ``values[0]``. Levels at or
    above the last breakpoint return ``values[-1]``.
    """
    if level <= breakpoints[0]:
        return values[0]
    if level >= breakpoints[-1]:
        return values[-1]

    # The first breakpoint at or above `level` closes the containing segment.
    # Only indices valid in both lists are considered.
    for hi in range(1, min(len(breakpoints), len(values))):
        if level <= breakpoints[hi]:
            lo = hi - 1
            frac = (level - breakpoints[lo]) / (breakpoints[hi] - breakpoints[lo])
            return values[lo] + (values[hi] - values[lo]) * frac

    return values[-1]  # fallback; only reached when values is shorter than breakpoints

resolve_reference_pose

resolve_reference_pose(cam_config, env) -> EvalExocentricCameraConfig

Compute world-frame pos/forward/up from the reference body and return an updated copy.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def resolve_reference_pose(cam_config, env) -> EvalExocentricCameraConfig:
    """Return a copy of ``cam_config`` with world-frame ``pos``/``forward``/``up`` filled in.

    The pose comes from ``env.camera_manager.create_quaternion_camera_pose``.
    It is called with the first entry of ``cam_config.reference_body_names``
    and the config's ``camera_offset`` and ``camera_quaternion``, both passed
    as float32 arrays. The three returned arrays are stored as plain lists.
    """
    make_pose = env.camera_manager.create_quaternion_camera_pose
    body_name = cam_config.reference_body_names[0]
    offset = np.array(cam_config.camera_offset, dtype=np.float32)
    quaternion = np.array(cam_config.camera_quaternion, dtype=np.float32)
    world_pos, world_forward, world_up = make_pose(
        env,
        reference_body_name=body_name,
        camera_offset=offset,
        camera_quaternion=quaternion,
    )
    updates = {
        "pos": world_pos.tolist(),
        "forward": world_forward.tolist(),
        "up": world_up.tolist(),
    }
    return cam_config.model_copy(update=updates)

setup_eval_cameras

setup_eval_cameras(env: CPUMujocoEnv, eval_system: FrankaEvalCameraSystem, task_relevant_bodies: list[str], workspace_center: ndarray, rng_seed: int) -> None

Set up eval cameras: wrist via MJCF, exo via spherical perturbation.

For each camera in eval_system:

  • Wrist (MjcfCameraConfig): registered without noise, then perturbed using the shared seeded RNG; no visibility check.
  • Exo (EvalExocentricCameraConfig): reference pose is resolved from the shoulder mount, then apply_camera_perturbation samples a pose in spherical coords around the workspace center. Multiple attempts are made to satisfy visibility constraints; if all fail a CameraPlacementError is raised.

Parameters:

Name Type Description Default
env CPUMujocoEnv

CPUMujocoEnv with the scene already set up.

required
eval_system FrankaEvalCameraSystem

FrankaEvalCameraSystem with randomization ranges already interpolated for the desired level.

required
task_relevant_bodies list[str]

Body names to check visibility against.

required
workspace_center ndarray

3D centroid of task-relevant objects.

required
rng_seed int

Deterministic seed for repeatable placement.

required

Raises:

Type Description
CameraPlacementError

If an exo camera cannot be placed with visibility constraints after max_placement_attempts.

Source code in molmo_spaces/utils/eval_camera_randomization_utils.py
def setup_eval_cameras(
    env: CPUMujocoEnv,
    eval_system: FrankaEvalCameraSystem,
    task_relevant_bodies: list[str],
    workspace_center: np.ndarray,
    rng_seed: int,
) -> None:
    """Set up eval cameras: wrist via MJCF, exo via spherical perturbation.

    For each camera in *eval_system*:

    - **Wrist** (``MjcfCameraConfig``): registered with its noise settings
      cleared, then ``_apply_mjcf_camera_noise`` perturbs it using the shared
      seeded RNG.  No visibility check.
    - **Exo** (``EvalExocentricCameraConfig``): reference pose is resolved
      from the shoulder mount, then ``apply_camera_perturbation`` samples
      a pose in spherical coords around the workspace center.  Multiple
      attempts are made to satisfy visibility constraints; if all fail a
      ``CameraPlacementError`` is raised.
    - Any other config type is skipped with a warning.

    One ``np.random.RandomState(rng_seed)`` is shared by all cameras and
    consumed in list order. Placements are therefore repeatable for a given
    seed, but each camera's pose depends on the cameras (and failed attempts)
    processed before it.

    Args:
        env: CPUMujocoEnv with the scene already set up.
        eval_system: FrankaEvalCameraSystem with randomization ranges
            already interpolated for the desired level.
        task_relevant_bodies: Body names to check visibility against. If
            empty, the first sampled exo pose is accepted without checking.
        workspace_center: 3D centroid of task-relevant objects.
        rng_seed: Deterministic seed for repeatable placement.

    Raises:
        CameraPlacementError: If an exo camera cannot be placed with
            visibility constraints after ``max_placement_attempts``.
    """
    from molmo_spaces.configs.camera_configs import (
        EvalExocentricCameraConfig,
        MjcfCameraConfig,
    )
    from molmo_spaces.tasks.task_sampler_errors import CameraPlacementError

    camera_manager = env.camera_manager
    rng = np.random.RandomState(rng_seed)

    for cam in eval_system.cameras:
        # --- Wrist cameras: register clean, then apply noise with seeded rng ---
        if isinstance(cam, MjcfCameraConfig):
            clean_cam = cam.model_copy(
                update={
                    "pos_noise_range": None,
                    "orientation_noise_degrees": None,
                    "fov_noise_degrees": None,
                }
            )
            camera_manager._setup_mjcf_camera(env, clean_cam)
            # Noise is applied from the original (noisy) config so it draws from `rng`.
            _apply_mjcf_camera_noise(env, cam, rng)
            log.info(f"[EVAL CAMERA] '{cam.name}' placed (wrist)")
            continue

        # --- Exo cameras: spherical perturbation ---
        if not isinstance(cam, EvalExocentricCameraConfig):
            log.warning(f"[EVAL CAMERA] Unknown camera type for '{cam.name}': {type(cam).__name__}")
            continue

        resolved = resolve_reference_pose(cam, env)
        ref_forward = np.array(resolved.forward, dtype=np.float32)
        ref_up = np.array(resolved.up, dtype=np.float32)
        log.info(
            f"[EVAL CAMERA] '{cam.name}' reference pose: pos={np.round(resolved.pos, 3).tolist()}"
        )

        max_attempts = resolved.max_placement_attempts
        # With no bodies to check, the first sampled pose is accepted.
        has_visibility = bool(task_relevant_bodies)

        # for/else: the `else` branch runs only if no attempt hit `break`
        # (including max_attempts == 0, which raises immediately).
        for attempt in range(max_attempts):
            pos, forward, up, fov = apply_camera_perturbation(
                resolved,
                ref_forward,
                ref_up,
                workspace_center,
                rng,
            )
            camera_manager.add_camera(cam.name, pos, forward, up, fov)

            if not has_visibility or _check_camera_visibility(env, cam.name, task_relevant_bodies):
                log.info(
                    f"[EVAL CAMERA] '{cam.name}' placed (attempt {attempt + 1}/{max_attempts})"
                )
                break

            # Remove failed attempt before retrying
            if cam.name in camera_manager.registry.cameras:
                del camera_manager.registry.cameras[cam.name]
        else:
            raise CameraPlacementError(
                f"Failed to place eval camera '{cam.name}' with visibility of "
                f"{task_relevant_bodies} after {max_attempts} attempts"
            )

eval_utils

Evaluation utilities for logging stats and videos to wandb.

Classes:

Name Description
EpisodeResult

Result from a single evaluation episode.

Functions:

Name Description
collect_episode_results

Scan output directory for HDF5 files and extract episode results.

compose_episode_videos

Compose videos from multiple cameras for each episode.

compose_videos_side_by_side

Compose multiple videos side-by-side into a single video.

compute_eval_stats

Compute aggregate statistics from evaluation results.

create_video_results_table

Create and log a WandB table with videos and episode metadata.

load_video_frames

Load frames from a video file.

log_eval_results_to_wandb

Log evaluation results and composed videos to wandb.

log_eval_videos_to_wandb

Find and log evaluation videos to wandb.

parse_obs_scene

Parse obs_scene from HDF5 dataset.

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

EpisodeResult dataclass

EpisodeResult(episode_idx: int, house_id: int | str, success: bool, num_steps: int, task_description: str | None = None, object_name: str | None = None, seed: int | None = None, data_file_path: Path | None = None, oracle_done: bool | None = None, metadata: dict[str, Any] = dict())

Result from a single evaluation episode.

Attributes:

Name Type Description
episode_idx int

Index of the episode within its house.

house_id int | str

House identifier (int or str like "house_5").

success bool

Whether the episode was successful (at end of episode).

num_steps int

Number of steps taken in the episode.

task_description str | None

Natural language task description.

object_name str | None

Name of the target object (if applicable).

seed int | None

Random seed used for the episode.

data_file_path Path | None

Path to the HDF5 file containing this episode's data. Use this together with episode_idx to uniquely identify an episode, especially when there are multiple batches per house.

oracle_done bool | None

Whether success was achieved at ANY point during the episode.

metadata dict[str, Any]

Additional metadata about the episode.

data_file_path class-attribute instance-attribute
data_file_path: Path | None = None
episode_idx instance-attribute
episode_idx: int
house_id instance-attribute
house_id: int | str
metadata class-attribute instance-attribute
metadata: dict[str, Any] = field(default_factory=dict)
num_steps instance-attribute
num_steps: int
object_name class-attribute instance-attribute
object_name: str | None = None
oracle_done class-attribute instance-attribute
oracle_done: bool | None = None
seed class-attribute instance-attribute
seed: int | None = None
success instance-attribute
success: bool
task_description class-attribute instance-attribute
task_description: str | None = None

collect_episode_results

collect_episode_results(output_dir: Path) -> list[EpisodeResult]

Scan output directory for HDF5 files and extract episode results.

Parameters:

Name Type Description Default
output_dir Path

Directory containing evaluation output

required

Returns:

Type Description
list[EpisodeResult]

List of EpisodeResult objects. Each result includes data_file_path to uniquely identify the episode even when there are multiple batches per house.

Source code in molmo_spaces/utils/eval_utils.py
def collect_episode_results(output_dir: Path) -> list[EpisodeResult]:
    """Scan output directory for HDF5 files and extract episode results.

    Expects the layout ``output_dir/house_*/trajectories*.h5``, where each file
    holds ``traj_<idx>`` groups. For every trajectory group:

    * ``success`` dataset: ``success`` is its last value and ``oracle_done`` is
      True if any value is truthy. Both are False when the dataset is missing
      or empty.
    * ``actions`` group: the episode length is the length of its first
      sub-dataset minus 2, clamped at 0 (0 when the group is missing or empty).
    * ``obs_scene`` dataset: parsed with ``parse_obs_scene``. The task
      description comes from ``task_description`` (falling back to ``text``),
      and ``object_name`` is read as-is.

    Args:
        output_dir: Directory containing evaluation output

    Returns:
        List of EpisodeResult objects. Each result includes data_file_path
        to uniquely identify the episode even when there are multiple batches
        per house.
    """
    results = []

    # Find all house directories
    for house_dir in sorted(output_dir.iterdir()):
        if not house_dir.is_dir() or not house_dir.name.startswith("house_"):
            continue

        # The directory name itself (e.g. "house_5") is used as the house id.
        house_id = house_dir.name

        # Find HDF5 files in this house directory
        for hdf5_path in sorted(house_dir.glob("trajectories*.h5")):
            with h5py.File(hdf5_path, "r") as f:
                for traj_key in sorted(f.keys()):
                    if not traj_key.startswith("traj_"):
                        continue

                    traj_group = f[traj_key]
                    episode_idx = int(traj_key.split("_")[1])

                    # Extract success (at end) and oracle_done (at any point)
                    success = False
                    oracle_done = False
                    if "success" in traj_group:
                        success_array = np.array(traj_group["success"])
                        if len(success_array) > 0:
                            success = bool(success_array[-1])
                            oracle_done = bool(np.any(success_array))

                    # Extract episode length from the first action sub-dataset.
                    num_steps = 0
                    if "actions" in traj_group:
                        actions_group = traj_group["actions"]
                        first_action_key = next(iter(actions_group), None)
                        if first_action_key is not None:
                            # NOTE(review): the "- 2" offset is kept from the original
                            # code; its exact meaning is not visible here. It is
                            # clamped so very short recordings never yield a negative
                            # length, which would corrupt min/avg length stats.
                            num_steps = max(0, len(actions_group[first_action_key]) - 2)

                    # Extract task description from obs_scene
                    task_description = None
                    object_name = None
                    if "obs_scene" in traj_group:
                        obs_scene = parse_obs_scene(traj_group["obs_scene"][()])
                        task_description = obs_scene.get("task_description") or obs_scene.get(
                            "text"
                        )
                        object_name = obs_scene.get("object_name")

                    results.append(
                        EpisodeResult(
                            episode_idx=episode_idx,
                            house_id=house_id,
                            success=success,
                            num_steps=num_steps,
                            task_description=task_description,
                            object_name=object_name,
                            data_file_path=hdf5_path,
                            oracle_done=oracle_done,
                        )
                    )

    return results

compose_episode_videos

compose_episode_videos(eval_dir: Path, camera_names: list[str], output_dir: Path | None = None, success_status: dict[str, bool] | None = None) -> dict[str, Path]

Compose videos from multiple cameras for each episode.

Parameters:

Name Type Description Default
eval_dir Path

Directory containing evaluation videos

required
camera_names list[str]

List of camera names to compose

required
output_dir Path | None

Directory to save composed videos (defaults to eval_dir/composed)

None
success_status dict[str, bool] | None

Optional dict mapping episode keys to success status

None

Returns:

Type Description
dict[str, Path]

Dict mapping episode keys to composed video paths

Source code in molmo_spaces/utils/eval_utils.py
def compose_episode_videos(
    eval_dir: Path,
    camera_names: list[str],
    output_dir: Path | None = None,
    success_status: dict[str, bool] | None = None,
) -> dict[str, Path]:
    """Compose videos from multiple cameras for each episode.

    Args:
        eval_dir: Directory containing evaluation videos
        camera_names: List of camera names to compose
        output_dir: Directory to save composed videos (defaults to eval_dir/composed)
        success_status: Optional dict mapping episode keys to success status

    Returns:
        Dict mapping episode keys to composed video paths
    """
    if output_dir is None:
        output_dir = eval_dir / "composed"

    # Find all episodes by scanning for video files
    episode_videos = defaultdict(dict)

    for cam_name in camera_names:
        for video_path in eval_dir.glob(f"**/episode_*_{cam_name}*.mp4"):
            # Extract episode key from path
            house_dir = video_path.parent.name
            # Parse episode index from filename
            match = re.match(r"episode_(\d+)_", video_path.name)
            if match:
                episode_idx = int(match.group(1))
                episode_key = f"{house_dir}/episode_{episode_idx:08d}"
                episode_videos[episode_key][cam_name] = video_path

    composed_paths = {}
    for episode_key, cam_paths in sorted(episode_videos.items()):
        # Only compose if we have all cameras
        if len(cam_paths) < len(camera_names):
            continue

        # Order cameras consistently
        video_paths = [cam_paths[cam] for cam in camera_names if cam in cam_paths]

        # Determine success suffix for filename
        success_suffix = ""
        if success_status and episode_key in success_status:
            success_suffix = "_success" if success_status[episode_key] else "_failed"

        output_path = output_dir / f"{episode_key.replace('/', '_')}_composed{success_suffix}.mp4"
        result = compose_videos_side_by_side(video_paths, output_path)
        if result:
            composed_paths[episode_key] = result

    return composed_paths

compose_videos_side_by_side

compose_videos_side_by_side(video_paths: list[Path], output_path: Path, target_height: int = 368, target_width: int = 1280) -> Path | None

Compose multiple videos side-by-side into a single video.

Parameters:

Name Type Description Default
video_paths list[Path]

List of paths to input videos

required
output_path Path

Path for the output composed video

required
target_height int

Target height for the output video

368
target_width int

Target width for the output video

1280

Returns:

Type Description
Path | None

Path to the composed video, or None if failed

Source code in molmo_spaces/utils/eval_utils.py
def compose_videos_side_by_side(
    video_paths: list[Path],
    output_path: Path,
    target_height: int = 368,
    target_width: int = 1280,
) -> Path | None:
    """Compose multiple videos side-by-side into a single video.

    Args:
        video_paths: List of paths to input videos
        output_path: Path for the output composed video
        target_height: Target height for the output video
        target_width: Target width for the output video

    Returns:
        Path to the composed video, or None if failed
    """
    import cv2

    from molmo_spaces.utils.save_utils import save_frames_to_mp4

    if not video_paths:
        return None

    # Load all videos
    all_frames = []
    fps = None
    for vp in video_paths:
        if not vp.exists():
            log.warning(f"Video not found: {vp}")
            return None
        frames, video_fps = load_video_frames(vp)
        if not frames:
            log.warning(f"No frames in video: {vp}")
            return None
        all_frames.append(frames)
        if fps is None:
            fps = video_fps

    # Find the minimum number of frames across all videos
    min_frames = min(len(frames) for frames in all_frames)

    # Truncate all videos to the same length
    all_frames = [frames[:min_frames] for frames in all_frames]

    # Resize each video to fit within the target dimensions split equally
    n_videos = len(all_frames)
    per_video_width = target_width // n_videos
    resized_frames = []
    for frames in all_frames:
        h, w = frames[0].shape[:2]
        # Scale to fit target_height, then clamp width to per_video_width
        scale = min(target_height / h, per_video_width / w)
        new_h = int(h * scale)
        new_w = int(w * scale)
        resized = [cv2.resize(f, (new_w, new_h)) for f in frames]
        resized_frames.append(resized)
    all_frames = resized_frames

    # Compose frames side-by-side, padding to exact target dimensions
    composed_frames = []
    for frame_idx in range(min_frames):
        row_frames = [frames[frame_idx] for frames in all_frames]
        composed = np.concatenate(row_frames, axis=1)
        # Pad to exact target size if needed
        ch, cw = composed.shape[:2]
        if ch != target_height or cw != target_width:
            padded = np.zeros((target_height, target_width, 3), dtype=composed.dtype)
            padded[:ch, :cw] = composed
            composed = padded
        composed_frames.append(composed)

    # Save composed video using save_frames_to_mp4 (same as rest of codebase)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    stacked_array = np.array(composed_frames)

    # Ensure uint8 format
    if stacked_array.dtype != np.uint8:
        if stacked_array.max() <= 1.0:
            stacked_array = (stacked_array * 255).astype(np.uint8)
        else:
            stacked_array = stacked_array.astype(np.uint8)

    save_frames_to_mp4(stacked_array, str(output_path), fps=fps or 30.0)

    return output_path

compute_eval_stats

compute_eval_stats(results: list[EpisodeResult]) -> dict[str, Any]

Compute aggregate statistics from evaluation results.

Parameters:

Name Type Description Default
results list[EpisodeResult]

List of episode results

required

Returns:

Type Description
dict[str, Any]

Dict of aggregate statistics

Source code in molmo_spaces/utils/eval_utils.py
def compute_eval_stats(results: list[EpisodeResult]) -> dict[str, Any]:
    """Aggregate a list of episode results into summary metrics.

    Always-present keys (when ``results`` is non-empty):

    * ``total_episodes``, ``success_count``, ``failure_count``, ``success_rate``
    * ``avg_episode_length``, ``min_episode_length``, ``max_episode_length``
    * ``num_houses`` and ``house_success_rates`` (house id -> success rate)

    Conditional keys:

    * ``oracle_done_count`` / ``oracle_done_rate``: added only when at least one
      result has a non-None ``oracle_done``.
    * ``avg_successful_episode_length``: added only when at least one episode
      succeeded.

    Args:
        results: List of episode results

    Returns:
        Dict of aggregate statistics; empty if ``results`` is empty.
    """
    if not results:
        return {}

    n_episodes = len(results)
    n_success = sum(r.success for r in results)
    step_counts = [r.num_steps for r in results]

    # Success flags grouped by house, in first-seen order.
    per_house = defaultdict(list)
    for r in results:
        per_house[r.house_id].append(r.success)

    stats: dict[str, Any] = {
        "total_episodes": n_episodes,
        "success_count": n_success,
        "failure_count": n_episodes - n_success,
        "success_rate": n_success / n_episodes,
        "avg_episode_length": sum(step_counts) / len(step_counts),
        "min_episode_length": min(step_counts),
        "max_episode_length": max(step_counts),
        "num_houses": len(per_house),
        "house_success_rates": {
            house: sum(flags) / len(flags) for house, flags in per_house.items()
        },
    }

    # Oracle done = success reached at any point during the episode.
    oracle_flags = [r.oracle_done for r in results if r.oracle_done is not None]
    if oracle_flags:
        n_oracle = sum(oracle_flags)
        stats["oracle_done_count"] = n_oracle
        stats["oracle_done_rate"] = n_oracle / len(oracle_flags)

    success_lengths = [r.num_steps for r in results if r.success]
    if success_lengths:
        stats["avg_successful_episode_length"] = sum(success_lengths) / len(success_lengths)

    return stats

create_video_results_table

create_video_results_table(episode_data: list[dict], table_name: str = 'eval/video_results') -> None

Create and log a WandB table with videos and episode metadata.

This is a shared utility for both distributed and non-distributed evaluation. Each dict in episode_data should contain:

- video_path: Path to the video file (required)
- task_description: Natural language task description
- object_name: Target object name
- house_id: House identifier
- episode_idx: Episode index
- num_steps: Number of steps taken
- success: Boolean success status (at end of episode)
- oracle_done: Boolean, success at ANY point during episode (optional)
- source_episode_path: Original episode path (optional, for provenance)

Parameters:

Name Type Description Default
episode_data list[dict]

List of dicts with video paths and metadata

required
table_name str

Name for the WandB table (default: "eval/video_results")

'eval/video_results'
Source code in molmo_spaces/utils/eval_utils.py
def create_video_results_table(
    episode_data: list[dict],
    table_name: str = "eval/video_results",
) -> None:
    """Build a WandB table of episode videos plus metadata and log it.

    Shared by distributed and non-distributed evaluation. Recognized keys in
    each ``episode_data`` dict:

    - video_path: Path to the video file (required; rows whose path is missing
      or does not exist on disk are skipped)
    - task_description, object_name, source_episode_path (default "")
    - house_id (stringified, default "")
    - episode_idx, num_steps (default 0)
    - success: rendered as "Success"/"Failed" (at end of episode)
    - oracle_done: rendered as "Yes"/"No" (success at ANY point, optional)

    Nothing is logged when no row survives the filter.

    Args:
        episode_data: List of dicts with video paths and metadata
        table_name: Name for the WandB table (default: "eval/video_results")
    """
    import wandb

    def _has_video(ep: dict) -> bool:
        path = ep.get("video_path")
        return path is not None and Path(path).exists()

    rows = [
        [
            wandb.Video(str(ep["video_path"]), format="mp4"),
            ep.get("task_description", ""),
            ep.get("object_name", ""),
            str(ep.get("house_id", "")),
            ep.get("episode_idx", 0),
            ep.get("num_steps", 0),
            "Success" if ep.get("success") else "Failed",
            "Yes" if ep.get("oracle_done") else "No",
            ep.get("source_episode_path", ""),
        ]
        for ep in episode_data
        if _has_video(ep)
    ]

    if not rows:
        return

    columns = [
        "video",
        "task_description",
        "object_name",
        "house_id",
        "episode_idx",
        "num_steps",
        "result",
        "oracle_done",
        "source_episode_path",
    ]
    wandb.log({table_name: wandb.Table(data=rows, columns=columns)})
    log.info(f"Uploaded {table_name} with {len(rows)} episodes")
        log.info(f"Uploaded {table_name} with {len(table_data)} episodes")

load_video_frames

load_video_frames(video_path: Path) -> tuple[list[ndarray], float]

Load frames from a video file.

Parameters:

Name Type Description Default
video_path Path

Path to the video file

required

Returns:

Type Description
tuple[list[ndarray], float]

Tuple of (list of frames as numpy arrays in RGB format, fps)

Source code in molmo_spaces/utils/eval_utils.py
def load_video_frames(video_path: Path) -> tuple[list[np.ndarray], float]:
    """Load frames from a video file.

    Frames are read with OpenCV until ``read()`` reports no more frames, and
    each one is converted from BGR to RGB. This function does not raise when
    the file cannot be opened. In that case the frame list is empty and the
    FPS is whatever OpenCV reports for the failed capture (presumably 0.0;
    verify for the OpenCV version in use).

    Args:
        video_path: Path to the video file

    Returns:
        Tuple of (list of frames as numpy arrays in RGB format, fps)
    """
    import cv2

    cap = cv2.VideoCapture(str(video_path))
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR to RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if decoding or conversion raises.
        cap.release()
    return frames, fps

log_eval_results_to_wandb

log_eval_results_to_wandb(results: list[EpisodeResult], composed_videos: dict[str, Path] | None = None) -> None

Log evaluation results and composed videos to wandb.

Creates a video table with composed videos in the first column and metadata (task description, episode length, success/fail, etc.) in subsequent columns.

Parameters:

Name Type Description Default
results list[EpisodeResult]

List of episode results

required
composed_videos dict[str, Path] | None

Optional dict mapping episode keys to composed video paths

None
Source code in molmo_spaces/utils/eval_utils.py
def log_eval_results_to_wandb(
    results: list[EpisodeResult],
    composed_videos: dict[str, Path] | None = None,
) -> None:
    """Push aggregate evaluation metrics and per-episode videos to wandb.

    Logged data:

    * Scalar metrics from ``compute_eval_stats``, written to ``wandb.summary``
      under ``eval/...`` keys. Missing core metrics default to 0.
    * Per-house success rates, logged as the ``eval/house_success_rates`` table.
    * When ``composed_videos`` is given, a video table built via
      ``create_video_results_table``. Each entry whose key has the form
      ``"<house_id>/episode_<idx:08d>"`` and matches a result becomes one row.

    Args:
        results: List of episode results
        composed_videos: Optional dict mapping episode keys to composed video paths
    """
    import wandb

    stats = compute_eval_stats(results)

    # Final metrics for a single checkpoint evaluation: stored as summary
    # values rather than as a time series.
    summary_fields = (
        ("eval/total_episodes", "total_episodes", 0),
        ("eval/success_count", "success_count", 0),
        ("eval/failure_count", "failure_count", 0),
        ("eval/success_rate", "success_rate", 0.0),
        ("eval/avg_episode_length", "avg_episode_length", 0.0),
        ("eval/min_episode_length", "min_episode_length", 0),
        ("eval/max_episode_length", "max_episode_length", 0),
        ("eval/num_houses", "num_houses", 0),
    )
    for summary_key, stat_key, fallback in summary_fields:
        wandb.summary[summary_key] = stats.get(stat_key, fallback)

    if stats.get("oracle_done_count") is not None:
        wandb.summary["eval/oracle_done_count"] = stats["oracle_done_count"]
        wandb.summary["eval/oracle_done_rate"] = stats["oracle_done_rate"]

    if stats.get("avg_successful_episode_length"):
        wandb.summary["eval/avg_successful_episode_length"] = stats["avg_successful_episode_length"]

    rates_by_house = stats.get("house_success_rates")
    if rates_by_house:
        house_rows = [[str(house), rate] for house, rate in sorted(rates_by_house.items())]
        wandb.log(
            {
                "eval/house_success_rates": wandb.Table(
                    data=house_rows, columns=["house_id", "success_rate"]
                )
            }
        )

    # Same key format as produced by compose_episode_videos.
    results_by_key = {f"{r.house_id}/episode_{r.episode_idx:08d}": r for r in results}

    if not composed_videos:
        return

    table_rows = []
    for episode_key, video_path in sorted(composed_videos.items()):
        match = results_by_key.get(episode_key)
        if match is None:
            continue
        table_rows.append(
            {
                "video_path": video_path,
                "task_description": match.task_description or "",
                "object_name": match.object_name or "",
                "house_id": match.house_id,
                "episode_idx": match.episode_idx,
                "num_steps": match.num_steps,
                "success": match.success,
                "oracle_done": match.oracle_done,
                "source_episode_path": match.metadata.get("source_h5_file", ""),
            }
        )

    create_video_results_table(table_rows)

log_eval_videos_to_wandb

log_eval_videos_to_wandb(eval_dir: Path, camera_names: list[str], epoch: int)

Find and log evaluation videos to wandb.

DEPRECATED: Use log_eval_results_to_wandb with compose_episode_videos instead.

Source code in molmo_spaces/utils/eval_utils.py
def log_eval_videos_to_wandb(eval_dir: Path, camera_names: list[str], epoch: int):
    """Find and log evaluation videos to wandb.

    DEPRECATED: Use log_eval_results_to_wandb with compose_episode_videos instead.

    Every file matching ``**/episode_*_{camera}*.mp4`` under ``eval_dir`` is
    logged under ``eval/video_<subdir>_<camera>_<suffix>``. Empty parts are
    omitted from the key. All videos go into one ``wandb.log`` call together
    with ``epoch``. Files wandb fails to wrap are reported with ``print`` and
    skipped. Nothing happens if ``eval_dir`` does not exist or no videos match.
    """
    import wandb

    if not eval_dir.exists():
        return

    # Collect matching files across all cameras, de-duplicated and sorted.
    matched = set()
    for cam in camera_names:
        matched.update(eval_dir.glob(f"**/episode_*_{cam}*.mp4"))
    video_files = sorted(matched)

    if not video_files:
        return

    wandb_videos = {}
    for video_path in video_files:
        # First camera name (in caller order) occurring in the filename.
        camera_name = None
        for cam in camera_names:
            if cam in video_path.name:
                camera_name = cam
                break
        if not camera_name:
            continue

        # Sub-directory name (e.g. "house_3"); empty for files directly in eval_dir.
        try:
            parent_name = video_path.relative_to(eval_dir).parent.name
        except ValueError:
            parent_name = ""
        else:
            if parent_name == ".":
                parent_name = ""

        # Whatever follows the camera name in the stem (e.g. "batch_1_of_1").
        stem = video_path.stem
        tail_start = stem.find(camera_name) + len(camera_name)
        suffix = stem[tail_start:].lstrip("_")

        joined = "_".join(part for part in ("video", parent_name, camera_name, suffix) if part)
        wandb_key = "eval/" + joined

        try:
            wandb_videos[wandb_key] = wandb.Video(str(video_path), format="mp4")
        except Exception as e:
            print(f"Error logging {video_path.name}: {e}")

    if wandb_videos:
        wandb.log({**wandb_videos, "epoch": epoch})

parse_obs_scene

parse_obs_scene(obs_scene_data) -> dict

Parse obs_scene from HDF5 dataset.

Source code in molmo_spaces/utils/eval_utils.py
def parse_obs_scene(obs_scene_data) -> dict:
    """Parse obs_scene from HDF5 dataset.

    Accepts the value read from an HDF5 dataset (e.g. ``dataset[()]``). Handled
    forms:

    * ``str`` is parsed directly.
    * ``bytes`` / ``bytearray`` are decoded as UTF-8.
    * A one-element numpy array wrapping one of these is unwrapped with
      ``.item()`` first.

    Any other value is converted with ``str()`` before JSON parsing.

    Raises:
        json.JSONDecodeError: if the resulting text is not valid JSON.
    """
    import json

    # str() of a one-element array would give "['...']" / "[b'...']", which
    # is not valid JSON, so unwrap it to the scalar it contains first.
    if isinstance(obs_scene_data, np.ndarray) and obs_scene_data.size == 1:
        obs_scene_data = obs_scene_data.item()

    if isinstance(obs_scene_data, (bytes, bytearray)):
        obs_scene_str = obs_scene_data.decode("utf-8")
    elif isinstance(obs_scene_data, str):
        obs_scene_str = obs_scene_data
    else:
        obs_scene_str = str(obs_scene_data)

    return json.loads(obs_scene_str)

fisheye_warping

GPU-accelerated fisheye lens distortion warping for camera images.

This module provides functions to apply fisheye distortion to images and videos, simulating the effect of wide-angle GoPro cameras. The warping is GPU-accelerated using PyTorch and uses a radial distortion model with parameters k1, k2, k3, k4.

Functions:

Name Description
apply_fisheye_warping_to_video_file

Apply fisheye warping to a video file and save the result.

calc_camera_intrinsics

Calculate camera intrinsic matrix from field of view and frame dimensions.

get_default_distortion_map

Get the default distortion map for a camera, loading from disk if necessary.

get_randomized_distortion_parameters

Get distortion parameters with random perturbations.

load_frames_from_mp4

Load frames from an MP4 video file.

make_distorted_grid

Create a distorted sampling grid for warping images.

warp_image_gpu

Apply fisheye distortion to an image using GPU acceleration.

warp_point

Warp a single point through the fisheye distortion.

warp_video_frames_batch

Apply fisheye warping to a list of video frames in batches.

warp_video_gpu

Apply fisheye distortion to a video using GPU acceleration.

apply_fisheye_warping_to_video_file

apply_fisheye_warping_to_video_file(video_path: Path | str, output_path: Path | str, K: ndarray, distortion_parameters: dict, crop_percent: float, output_shape: tuple[int, int] | None, device: device | None = None) -> bool

Apply fisheye warping to a video file and save the result.

Parameters:

Name Type Description Default
video_path Path | str

Path to input video

required
output_path Path | str

Path to save warped video

required
K ndarray

Camera intrinsic matrix

required
distortion_parameters dict

Distortion parameters

required
crop_percent float

Crop percentage after warping

required
output_shape tuple[int, int] | None

Output size (H, W) or None

required
device device | None

PyTorch device (defaults to CUDA if available)

None

Returns:

Type Description
bool

True if successful, False otherwise

Source code in molmo_spaces/utils/fisheye_warping.py
def apply_fisheye_warping_to_video_file(
    video_path: Path | str,
    output_path: Path | str,
    K: np.ndarray,
    distortion_parameters: dict,
    crop_percent: float,
    output_shape: tuple[int, int] | None,
    device: torch.device | None = None,
) -> bool:
    """Apply fisheye warping to a video file and save the result.

    The pipeline loads the frames, warps them in batches with
    ``warp_video_frames_batch``, and saves them with ``ffmpeg_save_video`` at
    the input FPS. Failures during loading, warping, or saving are logged
    with a traceback and reported via the return value rather than raised.

    Args:
        video_path: Path to input video
        output_path: Path to save warped video
        K: Camera intrinsic matrix
        distortion_parameters: Distortion parameters
        crop_percent: Crop percentage after warping
        output_shape: Output size (H, W) or None
        device: PyTorch device (defaults to CUDA if available)

    Returns:
        True if successful, False otherwise
    """
    import logging

    from molmo_spaces.utils.video_utils import ffmpeg_save_video

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Load frames
        frames, fps = load_frames_from_mp4(video_path)

        # Warp frames
        warped_frames = warp_video_frames_batch(
            frames=frames,
            K=K,
            distortion_parameters=distortion_parameters,
            crop_percent=crop_percent,
            output_shape=output_shape,
            device=device,
        )

        # Save warped video
        ffmpeg_save_video(warped_frames, str(output_path), fps=fps, pix_fmt="rgb24")
        return True
    except Exception:
        # Best-effort API: keep returning False, but record why it failed
        # instead of silently discarding the error.
        logging.getLogger(__name__).exception(
            "Fisheye warping failed for %s -> %s", video_path, output_path
        )
        return False

calc_camera_intrinsics

calc_camera_intrinsics(fov_y: float, frame_height: int, frame_width: int) -> ndarray

Calculate camera intrinsic matrix from field of view and frame dimensions.

Parameters:

Name Type Description Default
fov_y float

Vertical field of view in degrees

required
frame_height int

Image height in pixels

required
frame_width int

Image width in pixels

required

Returns:

Type Description
ndarray

3x3 camera intrinsic matrix K

Source code in molmo_spaces/utils/fisheye_warping.py
def calc_camera_intrinsics(fov_y: float, frame_height: int, frame_width: int) -> np.ndarray:
    """Build a pinhole intrinsic matrix from a vertical FOV and image size.

    The focal length (shared by x and y, i.e. square pixels) is chosen so that
    ``frame_height`` spans ``fov_y`` degrees. The principal point is the image
    center ``(frame_width / 2, frame_height / 2)``.

    Args:
        fov_y: Vertical field of view in degrees
        frame_height: Image height in pixels
        frame_width: Image width in pixels

    Returns:
        3x3 float64 camera intrinsic matrix K
    """
    half_fov = math.radians(fov_y / 2)
    focal = (frame_height / 2) / math.tan(half_fov)

    K = np.eye(3)
    K[0, 0] = focal
    K[1, 1] = focal
    K[0, 2] = frame_width / 2
    K[1, 2] = frame_height / 2
    return K

get_default_distortion_map

get_default_distortion_map() -> ndarray

Get the default distortion map for a camera, loading from disk if necessary.

Source code in molmo_spaces/utils/fisheye_warping.py
def get_default_distortion_map() -> np.ndarray:
    """Get the default distortion map for a camera, loading from disk if necessary.

    The map is loaded once from ``default_unity_distortion_map.npy`` and cached
    in the module-level ``_cached_map``. The file is looked up in the
    ``constants`` directory next to this module first, so loading works from
    any working directory. The legacy CWD-relative path
    ``molmo_spaces/utils/constants/default_unity_distortion_map.npy`` is tried
    second.

    Raises:
        FileNotFoundError: if the map exists at neither location.
        AssertionError: if the map's height/width differ from
            ``GOPRO_CAMERA_HEIGHT`` x ``GOPRO_CAMERA_WIDTH``.
    """
    global _cached_map
    if _cached_map is None:
        candidates = [
            # Package-relative: this module lives in molmo_spaces/utils/.
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "constants",
                "default_unity_distortion_map.npy",
            ),
            # Legacy path, only valid when the CWD is the repository root.
            "molmo_spaces/utils/constants/default_unity_distortion_map.npy",
        ]
        map_path = next((p for p in candidates if os.path.exists(p)), None)
        if map_path is None:
            raise FileNotFoundError(
                f"No default distortion map found at {' or '.join(candidates)}."
            )
        _cached_map = np.load(map_path)
        # Verify map dimensions
        assert (
            _cached_map.shape[0] == GOPRO_CAMERA_HEIGHT
            and _cached_map.shape[1] == GOPRO_CAMERA_WIDTH
        ), (
            f"Default distortion map has wrong size: {_cached_map.shape}, expected: {(GOPRO_CAMERA_HEIGHT, GOPRO_CAMERA_WIDTH)}"
        )
    return _cached_map

get_randomized_distortion_parameters

get_randomized_distortion_parameters(distortion_parameters: dict | None = None, randomization_factor: float = 0.001) -> dict

Get distortion parameters with random perturbations.

Parameters:

Name Type Description Default
distortion_parameters dict | None

Base distortion parameters (uses DEFAULT if None)

None
randomization_factor float

Magnitude of random perturbation

0.001

Returns:

Type Description
dict

Dictionary of randomized distortion parameters

Source code in molmo_spaces/utils/fisheye_warping.py
def get_randomized_distortion_parameters(
    distortion_parameters: dict | None = None,
    randomization_factor: float = 0.001,
) -> dict:
    """Get distortion parameters with random perturbations.

    Args:
        distortion_parameters: Base distortion parameters (uses DEFAULT if None)
        randomization_factor: Magnitude of random perturbation

    Returns:
        Dictionary of randomized distortion parameters
    """
    if distortion_parameters is None:
        distortion_parameters = DEFAULT_DISTORTION_PARAMETERS
    randomized_distortion_parameters = {}
    for key, value in distortion_parameters.items():
        randomized_distortion_parameters[key] = value + np.random.uniform(
            -randomization_factor, randomization_factor
        )
    return randomized_distortion_parameters

load_frames_from_mp4

load_frames_from_mp4(video_path: Path | str) -> tuple[list[ndarray], float]

Load frames from an MP4 video file.

Parameters:

Name Type Description Default
video_path Path | str

Path to MP4 video file

required

Returns:

Type Description
list[ndarray]

List of frames as numpy arrays (H, W, C) in RGB format

float

FPS of the video

Source code in molmo_spaces/utils/fisheye_warping.py
def load_frames_from_mp4(video_path: Path | str) -> tuple[list[np.ndarray], float]:
    """Load frames from an MP4 video file.

    Frames are read until OpenCV reports no more frames, and each one is
    converted from BGR to RGB. The capture handle is always released, even
    when an error is raised.

    Args:
        video_path: Path to MP4 video file

    Returns:
        List of frames as numpy arrays (H, W, C) in RGB format
        FPS of the video

    Raises:
        ValueError: if the video cannot be opened.
        AssertionError: if the reported FPS is not positive.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        assert fps > 0, f"Error reading FPS from video {video_path}, got {fps}"

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release even on the error paths above so the handle is not leaked.
        cap.release()
    return frames, fps

make_distorted_grid

make_distorted_grid(H: int, W: int, K: ndarray, distortion_parameters: dict, device: device | None = None, x_normalized: Tensor | None = None, y_normalized: Tensor | None = None, r: Tensor | None = None) -> Tensor

Create a distorted sampling grid for warping images.

Parameters:

Name Type Description Default
H int

Image height

required
W int

Image width

required
K ndarray

Camera intrinsic matrix (3x3)

required
distortion_parameters dict

Dict with keys k1, k2, k3, k4

required
device device | None

PyTorch device (defaults to CUDA if available)

None
x_normalized Tensor | None

Pre-computed normalized x coordinates (optional)

None
y_normalized Tensor | None

Pre-computed normalized y coordinates (optional)

None
r Tensor | None

Pre-computed radial distances (optional)

None

Returns:

Type Description
Tensor

Grid tensor of shape [1, H, W, 2] for use with grid_sample

Source code in molmo_spaces/utils/fisheye_warping.py
def make_distorted_grid(
    H: int,
    W: int,
    K: np.ndarray,
    distortion_parameters: dict,
    device: torch.device | None = None,
    x_normalized: torch.Tensor | None = None,
    y_normalized: torch.Tensor | None = None,
    r: torch.Tensor | None = None,
) -> torch.Tensor:
    """Create a distorted sampling grid for warping images.

    Args:
        H: Image height
        W: Image width
        K: Camera intrinsic matrix (3x3)
        distortion_parameters: Dict with keys k1, k2, k3, k4
        device: PyTorch device (defaults to CUDA if available)
        x_normalized: Pre-computed normalized x coordinates (optional)
        y_normalized: Pre-computed normalized y coordinates (optional)
        r: Pre-computed radial distances (optional)

    Returns:
        Grid tensor of shape [1, H, W, 2] for use with grid_sample
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if x_normalized is None or y_normalized is None or r is None:
        # Create meshgrid of pixel coordinates
        y, x = torch.meshgrid(
            torch.arange(H, device=device).float(),
            torch.arange(W, device=device).float(),
            indexing="ij",
        )

        # Normalize pixel coordinates using camera intrinsics
        x_normalized = (x - K[0, 2]) / K[0, 0]
        y_normalized = (y - K[1, 2]) / K[1, 1]

        r = torch.sqrt(x_normalized**2 + y_normalized**2)
    else:
        # Ensure the precomputed values are on the correct device
        x_normalized = x_normalized.to(device)
        y_normalized = y_normalized.to(device)
        r = r.to(device)

    # Extract distortion parameters
    k1, k2, k3, k4 = (distortion_parameters[k] for k in ["k1", "k2", "k3", "k4"])

    # Apply radial distortion
    distortion_factor = 1 + k1 * r**2 + k2 * r**4 + k3 * r**6 + k4 * r**8
    x_distorted = x_normalized * distortion_factor
    y_distorted = y_normalized * distortion_factor

    # Transform back to pixel coordinates
    x_distorted = x_distorted * K[0, 0] + K[0, 2]
    y_distorted = y_distorted * K[1, 1] + K[1, 2]

    # Normalize coordinates to [-1, 1] for grid_sample
    x_distorted = 2 * (x_distorted / (W - 1)) - 1
    y_distorted = 2 * (y_distorted / (H - 1)) - 1

    # Stack coordinates
    grid = torch.stack([x_distorted, y_distorted], dim=-1).unsqueeze(0)  # [1, H, W, 2]

    return grid

warp_image_gpu

warp_image_gpu(image: Tensor, K: ndarray | None = None, distortion_parameters: dict | None = None, crop_percent: float = DEFAULT_CROP_PERCENT, grid: Tensor | None = None, x_normalized: Tensor | None = None, y_normalized: Tensor | None = None, r: Tensor | None = None, output_shape: tuple[int, int] | None = None) -> Tensor

Apply fisheye distortion to an image using GPU acceleration.

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape [B, C, H, W]

required
K ndarray | None

Camera intrinsic matrix (required if grid is None)

None
distortion_parameters dict | None

Distortion parameters (required if grid is None)

None
crop_percent float

Fraction of the image height/width (e.g. 0.1 for 10%) to crop from each edge after warping

DEFAULT_CROP_PERCENT
grid Tensor | None

Pre-computed distortion grid (optional)

None
x_normalized Tensor | None

Pre-computed normalized x coordinates (optional)

None
y_normalized Tensor | None

Pre-computed normalized y coordinates (optional)

None
r Tensor | None

Pre-computed radial distances (optional)

None
output_shape tuple[int, int] | None

Target output size (H, W) for resizing (optional)

None

Returns:

Type Description
Tensor

Warped image tensor

Source code in molmo_spaces/utils/fisheye_warping.py
def warp_image_gpu(
    image: torch.Tensor,
    K: np.ndarray | None = None,
    distortion_parameters: dict | None = None,
    crop_percent: float = DEFAULT_CROP_PERCENT,
    grid: torch.Tensor | None = None,
    x_normalized: torch.Tensor | None = None,
    y_normalized: torch.Tensor | None = None,
    r: torch.Tensor | None = None,
    output_shape: tuple[int, int] | None = None,
) -> torch.Tensor:
    """Apply fisheye distortion to an image using GPU acceleration.

    Args:
        image: Input image tensor of shape [B, C, H, W]
        K: Camera intrinsic matrix (required if grid is None)
        distortion_parameters: Distortion parameters (required if grid is None)
        crop_percent: Fraction of the height/width (e.g. 0.1 for 10%) to crop
            from each edge after warping
        grid: Pre-computed distortion grid of shape [1, H, W, 2] or
            [B, H, W, 2] (optional). It is moved to ``image``'s device/dtype.
        x_normalized: Pre-computed normalized x coordinates (optional)
        y_normalized: Pre-computed normalized y coordinates (optional)
        r: Pre-computed radial distances (optional)
        output_shape: Target output size (H, W) for resizing (optional)

    Returns:
        Warped image tensor of shape [B, C, H', W'], where H', W' is the
        cropped size, or ``output_shape`` when given.
    """
    B, C, H, W = image.shape
    assert C == 3, "Input image should have 3 channels (RGB)"

    assert H == GOPRO_CAMERA_HEIGHT and W == GOPRO_CAMERA_WIDTH, (
        f"Image should be raw GoPro format, actually {H}x{W}"
    )

    if grid is None:
        assert distortion_parameters is not None, (
            "distortion_parameters must be provided if grid is not"
        )
        assert K is not None, "K must be provided if grid is not"
        grid = make_distorted_grid(
            H,
            W,
            K,
            distortion_parameters,
            device=image.device,
            x_normalized=x_normalized,
            y_normalized=y_normalized,
            r=r,
        )

    # grid_sample requires the grid to share the input's device and dtype; a
    # caller-supplied grid may have been built elsewhere.
    grid = grid.to(device=image.device, dtype=image.dtype)
    if grid.shape[0] == 1:
        # Broadcast one grid over the batch without materializing B copies.
        # A grid that is already batched ([B, H, W, 2]) is used as-is instead of
        # being repeated to B*B entries.
        grid = grid.expand(B, -1, -1, -1)  # [B, H, W, 2]

    distorted_image = F.grid_sample(
        image, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )

    crop_h = int(H * crop_percent)
    crop_w = int(W * crop_percent)
    # A zero crop must slice to None; "-0" would produce an empty slice.
    cropped_image = distorted_image[
        :, :, crop_h : -crop_h if crop_h > 0 else None, crop_w : -crop_w if crop_w > 0 else None
    ]

    if output_shape is not None:
        cropped_image = F.interpolate(
            cropped_image, size=output_shape, mode="bilinear", align_corners=True
        )

    return cropped_image

warp_point

warp_point(pixel_x: float, pixel_y: float, K: ndarray, distortion_parameters: dict, crop_percent: float, output_shape: tuple[int, int]) -> tuple[int, int]

Warp a single point through the fisheye distortion.

Parameters:

Name Type Description Default
pixel_x float

X coordinate in original image

required
pixel_y float

Y coordinate in original image

required
K ndarray

Camera intrinsic matrix

required
distortion_parameters dict

Distortion parameters

required
crop_percent float

Crop fraction (of height/width, per edge) used in warping

required
output_shape tuple[int, int]

Output image size (H, W)

required

Returns:

Type Description
tuple[int, int]

Tuple of (warped_x, warped_y) coordinates

Source code in molmo_spaces/utils/fisheye_warping.py
def warp_point(
    pixel_x: float,
    pixel_y: float,
    K: np.ndarray,
    distortion_parameters: dict,
    crop_percent: float,
    output_shape: tuple[int, int],
) -> tuple[int, int]:
    """Warp a single point through the fisheye distortion.

    The point is drawn as a single white pixel in an otherwise black raw
    GoPro-sized frame. That frame is warped with ``warp_image_gpu``, and the
    location of the brightest output pixel (channel 0) is returned.

    Args:
        pixel_x: X coordinate in original image
        pixel_y: Y coordinate in original image
        K: Camera intrinsic matrix
        distortion_parameters: Distortion parameters
        crop_percent: Crop fraction (of height/width, per edge) used in warping
        output_shape: Output image size (H, W)

    Returns:
        Tuple of (warped_x, warped_y) output coordinates as Python ints. If no
        output pixel receives any of the marked intensity (e.g. the point is
        cropped away), argmax picks the first pixel and (0, 0) is returned.

    Raises:
        IndexError: If ``(int(pixel_x), int(pixel_y))`` lies outside the raw
            GoPro frame.
    """
    col, row = int(pixel_x), int(pixel_y)
    # Validate explicitly: a negative index would otherwise wrap around silently
    # and mark a pixel on the opposite side of the frame.
    if not (0 <= row < GOPRO_CAMERA_HEIGHT and 0 <= col < GOPRO_CAMERA_WIDTH):
        raise IndexError(
            f"Point ({pixel_x}, {pixel_y}) is outside the "
            f"{GOPRO_CAMERA_WIDTH}x{GOPRO_CAMERA_HEIGHT} source frame"
        )

    # Create a blank frame with the point marked
    blank_frame = torch.zeros((1, 3, GOPRO_CAMERA_HEIGHT, GOPRO_CAMERA_WIDTH), dtype=torch.float32)
    blank_frame[0, :, row, col] = 1.0  # Mark the point as white

    # Warp the frame
    warped_frame = warp_image_gpu(
        blank_frame,
        K=K,
        distortion_parameters=distortion_parameters,
        crop_percent=crop_percent,
        output_shape=output_shape,
    )

    # Channel 0 of the single batch element. Indexing (rather than squeeze())
    # keeps the result 2-D even if an output dimension has size 1.
    channel = warped_frame[0, 0].cpu().numpy()
    warped_y, warped_x = np.unravel_index(np.argmax(channel), channel.shape)

    return int(warped_x), int(warped_y)

warp_video_frames_batch

warp_video_frames_batch(frames: list[ndarray], K: ndarray, distortion_parameters: dict, crop_percent: float, output_shape: tuple[int, int] | None, device: device, batch_size: int = 16) -> list[ndarray]

Apply fisheye warping to a list of video frames in batches.

Parameters:

Name Type Description Default
frames list[ndarray]

List of frames as numpy arrays (H, W, C)

required
K ndarray

Camera intrinsic matrix

required
distortion_parameters dict

Distortion parameters

required
crop_percent float

Crop fraction (of height/width, per edge) applied after warping

required
output_shape tuple[int, int] | None

Output size (H, W) or None

required
device device

PyTorch device

required
batch_size int

Number of frames to process at once

16

Returns:

Type Description
list[ndarray]

List of warped frames as numpy arrays (H, W, C)

Source code in molmo_spaces/utils/fisheye_warping.py
def warp_video_frames_batch(
    frames: list[np.ndarray],
    K: np.ndarray,
    distortion_parameters: dict,
    crop_percent: float,
    output_shape: tuple[int, int] | None,
    device: torch.device,
    batch_size: int = 16,
) -> list[np.ndarray]:
    """Apply fisheye warping to a list of video frames in batches.

    Args:
        frames: List of frames as numpy arrays (H, W, C), presumably uint8
            in [0, 255] (they are divided by 255) — all the same size
        K: Camera intrinsic matrix
        distortion_parameters: Distortion parameters
        crop_percent: Crop fraction (of height/width, per edge) after warping
        output_shape: Output size (H, W) or None
        device: PyTorch device
        batch_size: Number of frames to process at once

    Returns:
        List of warped frames as uint8 numpy arrays (H, W, C); empty if
        ``frames`` is empty.
    """
    warped_frames: list[np.ndarray] = []
    if not frames:
        return warped_frames

    K_tensor = torch.tensor(K, dtype=torch.float32, device=device)

    # The distortion grid depends only on frame size, intrinsics and distortion
    # parameters, so build it once instead of once per batch.
    frame_h, frame_w = frames[0].shape[:2]
    grid = make_distorted_grid(frame_h, frame_w, K_tensor, distortion_parameters, device=device)

    # Process in batches for efficiency
    for start in range(0, len(frames), batch_size):
        # Stack frames into batch tensor [B, H, W, C]; transfer before the
        # float conversion so the host->device copy is of the compact dtype.
        batch_array = np.stack(frames[start : start + batch_size], axis=0)
        batch_tensor = torch.from_numpy(batch_array).to(device).float() / 255.0

        # Permute to [B, C, H, W]
        batch_tensor = batch_tensor.permute(0, 3, 1, 2)

        warped_batch = warp_image_gpu(
            image=batch_tensor,
            grid=grid,
            crop_percent=crop_percent,
            output_shape=output_shape,
        )

        # Round to nearest and clamp before casting: a plain cast truncates,
        # biasing interpolated values down (e.g. 199.9999 -> 199).
        warped_uint8 = (
            (warped_batch.permute(0, 2, 3, 1) * 255).round().clamp(0, 255).to(torch.uint8)
        )
        warped_frames.extend(warped_uint8.cpu().numpy())

    return warped_frames

warp_video_gpu

warp_video_gpu(video: ndarray | Tensor, K: ndarray | None = None, randomize_distortion_parameters: bool = False, crop_percent: float = DEFAULT_CROP_PERCENT, output_shape: tuple[int, int] | None = None) -> ndarray

Apply fisheye distortion to a video using GPU acceleration.

Parameters:

Name Type Description Default
video ndarray | Tensor

Input video as numpy array [T, H, W, C] or tensor

required
K ndarray | None

Camera intrinsic matrix (computed from defaults if None)

None
randomize_distortion_parameters bool

Whether to randomize distortion params

False
crop_percent float

Fraction of the image height/width (e.g. 0.1 for 10%) to crop from each edge after warping

DEFAULT_CROP_PERCENT
output_shape tuple[int, int] | None

Target output size (H, W) for resizing (optional)

None

Returns:

Type Description
ndarray

Warped video as numpy array [T, H, W, C] with uint8 values

Source code in molmo_spaces/utils/fisheye_warping.py
def warp_video_gpu(
    video: np.ndarray | torch.Tensor,
    K: np.ndarray | None = None,
    randomize_distortion_parameters: bool = False,
    crop_percent: float = DEFAULT_CROP_PERCENT,
    output_shape: tuple[int, int] | None = None,
) -> np.ndarray:
    """Apply fisheye distortion to a video using GPU acceleration.

    Args:
        video: Input video as numpy array [T, H, W, C] or tensor, with values
            presumably in [0, 255] (they are divided by 255)
        K: Camera intrinsic matrix (computed from defaults if None)
        randomize_distortion_parameters: Whether to randomize distortion params
        crop_percent: Fraction of the height/width (e.g. 0.1 for 10%) to crop
            from each edge after warping
        output_shape: Target output size (H, W) for resizing (optional)

    Returns:
        Warped video as numpy array [T, H, W, C] with uint8 values
    """
    assert video.shape[2] == GOPRO_CAMERA_WIDTH and video.shape[1] == GOPRO_CAMERA_HEIGHT, (
        "Video should be raw GoPro format"
    )

    if randomize_distortion_parameters:
        distortion_parameters = get_randomized_distortion_parameters()
    else:
        distortion_parameters = DEFAULT_DISTORTION_PARAMETERS

    if K is None:
        K = calc_camera_intrinsics(GOPRO_VERTICAL_FOV, GOPRO_CAMERA_HEIGHT, GOPRO_CAMERA_WIDTH)

    if isinstance(video, torch.Tensor):
        # Tensors stay on whatever device the caller placed them.
        video_tensor = video.float() / 255.0
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Transfer before converting so the host->device copy is of the compact dtype.
        video_tensor = torch.from_numpy(video).to(device).float() / 255.0

    # Permute to [B, C, H, W] format
    video_tensor = video_tensor.permute(0, 3, 1, 2)

    warped_video = warp_image_gpu(
        image=video_tensor,
        K=K,
        distortion_parameters=distortion_parameters,
        crop_percent=crop_percent,
        output_shape=output_shape,
    )
    # Round to nearest and clamp before casting: a plain astype(uint8) truncates,
    # biasing interpolated values down (e.g. 199.9999 -> 199).
    warped_uint8 = (warped_video.permute(0, 2, 3, 1) * 255).round().clamp(0, 255).to(torch.uint8)
    return warped_uint8.cpu().numpy()

function_utils

Functions:

Name Description
make_lenient

Wrap func so extra args/kwargs are silently dropped.

make_lenient

make_lenient(func: Callable) -> Callable

Wrap func so extra args/kwargs are silently dropped.

Args matching a declared parameter are forwarded; leftover positional and keyword args flow into func's *args / **kwargs when present, and are dropped otherwise. Raises TypeError on double-binding and (via Python's normal call machinery) on missing required parameters.

Parameters:

Name Type Description Default
func Callable

The function (or class) to wrap.

required

Returns:

Type Description
Callable

A picklable callable.

Note

Positional-only parameters (after /) are not supported.

Source code in molmo_spaces/utils/function_utils.py
def make_lenient(func: Callable) -> Callable:
    """
    Return a wrapper around ``func`` that tolerates surplus arguments.

    Arguments that match one of ``func``'s declared parameters are passed
    through. Remaining positional / keyword arguments go to ``func``'s
    ``*args`` / ``**kwargs`` if it declares them, and are discarded otherwise.
    A ``TypeError`` is raised when an argument is bound twice and, through the
    ordinary call machinery, when a required parameter is missing.

    Args:
        func: Function (or class) to wrap.

    Returns:
        A picklable callable wrapping ``func``.

    Note:
        Positional-only parameters (declared before ``/``) are not supported.
    """
    wrapper = _LenientCallable(func)
    return wrapper

grasp_sample

This module contains functionality for filtering and sampling grasps based on heuristics.

Functions:

Name Description
add_grasp_collision_bodies

Add grasp collision bodies to the scene.

get_feasible_grasp_idx
get_grasp_collision_body_name
get_noncolliding_grasp_mask
select_grasp_pose

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

add_grasp_collision_bodies

add_grasp_collision_bodies(spec: MjSpec, num_grasps: int, grasp_width: float, grasp_length: float, grasp_height: float, grasp_base_pos: ndarray)

Add grasp collision bodies to the scene.

Source code in molmo_spaces/utils/grasp_sample.py
def add_grasp_collision_bodies(
    spec: mujoco.MjSpec,
    num_grasps: int,
    grasp_width: float,
    grasp_length: float,
    grasp_height: float,
    grasp_base_pos: np.ndarray,
):
    """Add grasp collision bodies to the scene.

    Each of the ``num_grasps`` bodies has a free joint and gravity
    compensation, and starts parked at ``[0, 0, 10]``. It is built from three
    cylinders of radius ``grasp_height / 2``, all offset by ``grasp_base_pos``:
    a base bar spanning ``grasp_width`` along y, and two fingers extending
    ``grasp_length`` along +z from either end of the bar. The geoms use
    ``contype=0`` / ``conaffinity=0b1111``, so they do not collide with each
    other.
    """
    half_width = grasp_width / 2
    radius = grasp_height / 2

    # (start, end) of each cylinder relative to grasp_base_pos, in creation
    # order: base bar, finger on the -y side, finger on the +y side.
    segments = (
        ([0, -half_width, 0], [0, half_width, 0]),
        ([0, -half_width, 0], [0, -half_width, grasp_length]),
        ([0, half_width, 0], [0, half_width, grasp_length]),
    )

    cylinder_kwargs = dict(
        type=mujoco.mjtGeom.mjGEOM_CYLINDER,
        rgba=[0, 0, 1, 1],
        group=3,
        contype=0,
        conaffinity=0b1111,
    )

    for idx in range(num_grasps):
        # Park bodies in the sky: below the ground they would collide with the floor.
        body = spec.worldbody.add_body(
            name=get_grasp_collision_body_name(idx),
            pos=[0, 0, 10],
            gravcomp=1.0,
        )
        body.add_freejoint()

        for start, end in segments:
            geom = body.add_geom(**cylinder_kwargs)
            geom.size[0] = radius
            geom.fromto[:3] = np.array(start) + grasp_base_pos
            geom.fromto[3:] = np.array(end) + grasp_base_pos

get_feasible_grasp_idx

get_feasible_grasp_idx(mg_id: str, robot: Robot, grasp_poses_world: ndarray, n_ik_checks: int, ik_batch_size: int)
Source code in molmo_spaces/utils/grasp_sample.py
def get_feasible_grasp_idx(
    mg_id: str,
    robot: Robot,
    grasp_poses_world: np.ndarray,
    n_ik_checks: int,
    ik_batch_size: int,
) -> int | None:
    """Return the index of the first grasp (in order) that has an IK solution.

    At most ``n_ik_checks`` grasps from ``grasp_poses_world`` are checked, in
    batches of ``ik_batch_size``. A short batch is padded to ``ik_batch_size``
    by repeating its last pose, so the IK solver sees a constant batch shape.
    Results for padded entries are ignored.

    Args:
        mg_id: Move-group id passed to the robot's parallel IK solver.
        robot: Robot whose ``parallel_kinematics.ik`` is queried.
        grasp_poses_world: Array of shape (N, 4, 4) of candidate grasp poses.
        n_ik_checks: Maximum number of grasps to check.
        ik_batch_size: Number of poses sent to the IK solver per call.

    Returns:
        Index into ``grasp_poses_world`` of the first feasible grasp, or
        ``None`` if none of the checked grasps is feasible.
    """
    # Never slice past the available grasps. With fewer grasps than
    # n_ik_checks, the old loop produced empty batches, and padding them via
    # np.broadcast_to raised.
    n_candidates = min(n_ik_checks, len(grasp_poses_world))
    n_checks_done = 0
    ret: int | None = None

    with Timer() as ik_check_time:
        for i in range(0, n_candidates, ik_batch_size):
            # Cap the final batch so no more than n_ik_checks grasps are checked.
            grasps = grasp_poses_world[i : min(i + ik_batch_size, n_candidates)]
            real_batch_size = len(grasps)
            n_checks_done += real_batch_size

            if real_batch_size < ik_batch_size:
                # pad to batch size to avoid triggering recompilation
                grasps = np.concatenate(
                    [grasps, np.broadcast_to(grasps[-1:], (ik_batch_size - real_batch_size, 4, 4))]
                )

            ik_result = robot.parallel_kinematics.ik(
                mg_id,
                grasps,
                None,
                robot.robot_view.get_qpos_dict(),
                robot.robot_view.base.pose,
                rel_to_base=False,
            )
            # First non-None result among the real (unpadded) entries.
            ret = next(
                (
                    i + j
                    for j, result in enumerate(ik_result[:real_batch_size])
                    if result is not None
                ),
                None,
            )
            if ret is not None:
                break
    log.info(
        f"Feasibility-checked {n_checks_done} grasps in {ik_check_time.value:.3f}s, found feasible grasp: {ret is not None}"
    )

    return ret

get_grasp_collision_body_name

get_grasp_collision_body_name(grasp_idx: int) -> str
Source code in molmo_spaces/utils/grasp_sample.py
def get_grasp_collision_body_name(grasp_idx: int) -> str:
    """Return the MuJoCo body name used for grasp collision body ``grasp_idx``."""
    return "grasp_collision_{}".format(grasp_idx)

get_noncolliding_grasp_mask

get_noncolliding_grasp_mask(mj_model: MjModel, mj_data: MjData, grasp_poses_world: ndarray, batch_size: int) -> ndarray
Source code in molmo_spaces/utils/grasp_sample.py
def get_noncolliding_grasp_mask(
    mj_model: mujoco.MjModel,
    mj_data: mujoco.MjData,
    grasp_poses_world: np.ndarray,
    batch_size: int,
) -> np.ndarray:
    """Return a boolean mask marking which grasp poses are collision-free.

    Uses the ``batch_size`` bodies named by ``get_grasp_collision_body_name``
    (presumably created beforehand with ``add_grasp_collision_bodies``).
    Grasps are tested ``batch_size`` at a time, in three steps:

    1. Each body is placed at one grasp pose, and bodies unused in the batch
       are returned to their starting pose.
    2. ``mj_kinematics`` and ``mj_collision`` are run.
    3. Any contact involving a placed body marks its grasp as colliding.

    All bodies are restored to their starting poses and ``mj_fwdPosition`` is
    run before returning, even if an error occurs.

    Args:
        mj_model: MuJoCo model.
        mj_data: MuJoCo data; body poses are modified temporarily.
        grasp_poses_world: Array of N grasp poses to test.
        batch_size: Number of grasp collision bodies to use per pass.

    Returns:
        Boolean array of length N; True means no contact was found for that grasp.
    """
    n_grasps = len(grasp_poses_world)
    grasp_bodies = [
        create_mlspaces_body(mj_data, get_grasp_collision_body_name(i)) for i in range(batch_size)
    ]
    start_poses = [body.pose.copy() for body in grasp_bodies]
    grasp_body_ids = {body.body_id for body in grasp_bodies}

    try:
        colliding_grasp_mask = np.zeros(n_grasps, dtype=bool)
        for i in range(0, n_grasps, batch_size):
            grasp_bid_to_idx = {}
            n_grasps_in_batch = min(batch_size, n_grasps - i)
            for j in range(i, i + n_grasps_in_batch):
                grasp_body = grasp_bodies[j - i]
                grasp_body.pose = grasp_poses_world[j]
                grasp_bid_to_idx[grasp_body.body_id] = j
            # Park the bodies left unused by a final, partial batch.
            for j in range(n_grasps_in_batch, len(grasp_bodies)):
                grasp_bodies[j].pose = start_poses[j]

            mujoco.mj_kinematics(mj_model, mj_data)
            mujoco.mj_collision(mj_model, mj_data)
            for contact in mj_data.contact:
                bid1 = mj_model.geom_bodyid[contact.geom1]
                bid2 = mj_model.geom_bodyid[contact.geom2]
                if bid1 not in grasp_body_ids and bid2 not in grasp_body_ids:
                    continue
                grasp_bid = bid1 if bid1 in grasp_body_ids else bid2
                other_bid = bid2 if grasp_bid == bid1 else bid1
                # Grasp bodies are expected never to touch each other
                # (add_grasp_collision_bodies gives their geoms contype=0).
                assert other_bid not in grasp_body_ids

                grasp_idx = grasp_bid_to_idx.get(grasp_bid)
                if grasp_idx is None:
                    # A parked (unused) body touched something. It corresponds
                    # to no grasp in this batch; previously this raised KeyError.
                    continue
                colliding_grasp_mask[grasp_idx] = True

        return ~colliding_grasp_mask
    finally:
        # move the grasp bodies back out of the way
        for body, pose in zip(grasp_bodies, start_poses):
            body.pose = pose
        mujoco.mj_fwdPosition(mj_model, mj_data)

select_grasp_pose

select_grasp_pose(env: CPUMujocoEnv, grasp_poses_world: ndarray, object_pose: ndarray, check_collision: bool, n_collision_checks: int, collision_batch_size: int, check_ik: bool, n_ik_checks: int, ik_batch_size: int, pos_cost_weight: float = 1.0, rot_cost_weight: float = 0.01, vertical_cost_weight: float = 2.0, horizontal_cost_weight: float = 0, com_dist_cost_weight: float = 8.0) -> ndarray
Source code in molmo_spaces/utils/grasp_sample.py
def select_grasp_pose(
    env: CPUMujocoEnv,
    grasp_poses_world: np.ndarray,
    object_pose: np.ndarray,
    check_collision: bool,
    n_collision_checks: int,
    collision_batch_size: int,
    check_ik: bool,
    n_ik_checks: int,
    ik_batch_size: int,
    pos_cost_weight: float = 1.0,
    rot_cost_weight: float = 0.01,
    vertical_cost_weight: float = 2.0,
    horizontal_cost_weight: float = 0,
    com_dist_cost_weight: float = 8.0,
) -> np.ndarray:
    """Pick the lowest-cost grasp that passes the optional collision and IK filters.

    Each grasp's cost is a weighted sum of five terms:

    - its position distance to the current TCP pose;
    - its rotation angle (degrees) relative to the TCP;
    - the world-z component of its z-axis (favors grasps pointing down);
    - the square of that component (favors horizontal grasps);
    - its distance from ``object_pose``'s origin.

    The ``n_collision_checks`` cheapest grasps are optionally collision-filtered,
    and the survivors are optionally IK-filtered in cost order.

    Args:
        env: Environment providing the current robot, model and data.
        grasp_poses_world: Array of shape (N, 4, 4) of candidate grasp poses.
        object_pose: 4x4 pose of the object in the world frame.
        check_collision: Whether to filter grasps by collision.
        n_collision_checks: Number of lowest-cost grasps to consider.
        collision_batch_size: Batch size for collision checking.
        check_ik: Whether to require an IK solution.
        n_ik_checks: Maximum number of grasps to IK-check.
        ik_batch_size: Batch size for IK checking.
        pos_cost_weight, rot_cost_weight, vertical_cost_weight,
        horizontal_cost_weight, com_dist_cost_weight: Cost-term weights.

    Returns:
        The selected 4x4 grasp pose.

    Raises:
        ValueError: If no grasp passes the filters (including empty input).
    """
    # Fail fast and consistently on empty input, before any batched math.
    if len(grasp_poses_world) == 0:
        raise ValueError("No feasible grasp found")

    robot = env.current_robot
    gripper_mg_id = robot.robot_view.get_gripper_movegroup_ids()[0]
    tcp_pose = robot.robot_view.get_move_group(gripper_mg_id).leaf_frame_to_world
    tcp_pose_inv = np.linalg.inv(tcp_pose)

    dist_tcp = tcp_pose_inv @ grasp_poses_world  # shape (N,4,4)
    dists_tcp_p = np.linalg.norm(dist_tcp[:, :3, 3], axis=1)
    dist_tcp_o = R.from_matrix(dist_tcp[:, :3, :3]).magnitude() * 180 / np.pi

    # World-z component of each grasp's z-axis; range = [-1, 1]
    dists_up = grasp_poses_world[:, 2, 2]

    dists_com = np.linalg.norm((np.linalg.inv(object_pose) @ grasp_poses_world)[:, :3, 3], axis=1)

    # Cost for horizontal orientation: 0 = perfectly horizontal (z-axis parallel to XY plane), 1 = vertical
    # Lower cost = more horizontal, so we want to minimize this
    # Use squared term to more strongly penalize vertical orientations
    dists_xy_parallel = np.abs(dists_up) ** 2

    dist_total = (
        pos_cost_weight * dists_tcp_p
        + rot_cost_weight * dist_tcp_o
        + vertical_cost_weight * dists_up
        + horizontal_cost_weight * dists_xy_parallel
        + com_dist_cost_weight * dists_com
    )
    close_grasp_ids = np.argsort(dist_total, kind="stable")  # weight positions and orientations
    close_grasp_ids = close_grasp_ids[:n_collision_checks]

    # filter for noncolliding grasps
    if check_collision:
        with Timer() as collision_check_time:
            noncolliding_grasp_mask = get_noncolliding_grasp_mask(
                env.current_model,
                env.current_data,
                grasp_poses_world[close_grasp_ids],
                collision_batch_size,
            )

        log.info(
            f"Collision-checked {len(close_grasp_ids)} grasps in {collision_check_time.value:.3f}s, found {np.sum(noncolliding_grasp_mask)} non-colliding grasps"
        )
    else:
        noncolliding_grasp_mask = np.ones(len(close_grasp_ids), dtype=bool)

    noncolliding_close_grasp_ids = close_grasp_ids[noncolliding_grasp_mask]

    # Pick the cheapest surviving grasp, optionally requiring reachability.
    grasp_idx: int | None = None
    if noncolliding_close_grasp_ids.size > 0:
        if check_ik:
            feasible_idx = get_feasible_grasp_idx(
                gripper_mg_id,
                robot,
                grasp_poses_world[noncolliding_close_grasp_ids],
                n_ik_checks,
                ik_batch_size,
            )
            if feasible_idx is not None:
                grasp_idx = int(noncolliding_close_grasp_ids[feasible_idx])
        else:
            grasp_idx = int(noncolliding_close_grasp_ids[0])

    if grasp_idx is None:
        raise ValueError("No feasible grasp found")

    return grasp_poses_world[grasp_idx]

grasps

This module contains functionality for loading grasps from registered grasp libraries.

Note: this module caches aggressively, so grasp/asset libraries must be registered before the first call into this module. New registrations will not be visible until the caches are cleared.

Functions:

Name Description
flip_grasps
get_grasp_libraries_for_object
get_joint_grasp_path
get_joint_grasps

Load the first available joint grasps for a given object and joint in the world frame.

get_pickup_grasp_path
get_pickup_grasps

Load the first available pickup grasps for a given object in the world frame.

has_joint_grasp_path
has_pickup_grasp_path
has_valid_joint_grasps
has_valid_pickup_grasps
load_joint_grasps

Load the first available joint grasps for a given object and joint in the joint's local frame.

load_pickup_grasps

Load the first available pickup grasps for a given object in the local frame.

sanitize_grasp_library_list_and_cache

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

flip_grasps

flip_grasps(grasps: ndarray) -> ndarray
Source code in molmo_spaces/utils/grasps.py
def flip_grasps(grasps: np.ndarray) -> np.ndarray:
    """Rotate each grasp 180 degrees about its own local z-axis.

    The rotation is applied on the right (in the grasp frame), so grasp
    positions are unchanged.
    """
    rot_z_180 = R.from_euler("z", 180, degrees=True).as_matrix()
    flip = np.identity(4)
    flip[:3, :3] = rot_z_180
    return np.matmul(grasps, flip)

get_grasp_libraries_for_object

get_grasp_libraries_for_object(uid: str) -> list[str]
Source code in molmo_spaces/utils/grasps.py
def get_grasp_libraries_for_object(uid: str) -> list[str]:
    """Return the grasp libraries registered for the object library containing ``uid``.

    Returns an empty list when that object library has no registered grasp libraries.
    """
    object_library, _, _ = _locate_uid_package(uid)
    if object_library in OBJECT_LIBRARY_TO_GRASP_LIBRARIES:
        return list(OBJECT_LIBRARY_TO_GRASP_LIBRARIES[object_library])
    return []

get_joint_grasp_path

get_joint_grasp_path(uid: str, joint_name: str, grasp_libraries: Sequence[str] | None = None) -> Path | None
Source code in molmo_spaces/utils/grasps.py
def get_joint_grasp_path(
    uid: str, joint_name: str, grasp_libraries: Sequence[str] | None = None
) -> Path | None:
    """Return the first existing joint-grasp file for ``uid``/``joint_name``, or ``None``.

    User-registered libraries are resolved through their index's
    ``articulated_grasp_paths``. Any other library falls back to the builtin
    droid (thor) layout under ``ASSETS_DIR``.
    """
    # With exactly one library requested, use it directly and fail later if the
    # file is missing. Thor articulated objects cannot be looked up by uid, so
    # skipping the lookup is the workaround; client code doing articulated
    # manipulation with thor should specify exactly one grasp library.
    single_library = grasp_libraries is not None and len(grasp_libraries) == 1
    candidate_libs = (
        grasp_libraries
        if single_library
        else _filter_grasp_libraries_for_object(uid, grasp_libraries)
    )

    for library in candidate_libs:
        if library not in USER_GRASP_LIBRARIES:
            # droid (thor) is the only builtin grasp library with joint grasps
            candidate = ASSETS_DIR / f"grasps/droid/{uid}/{joint_name}_grasps_filtered.npz"
        else:
            library_dir = USER_GRASP_LIBRARIES[library]
            index = get_user_grasp_library_index(library_dir)
            robot_name = library.split("/", 1)[-1]
            per_robot = index.articulated_grasp_paths.get(robot_name, {})
            relative = per_robot.get(uid, {}).get(joint_name, None)
            candidate = None if relative is None else library_dir / relative

        if candidate is not None and candidate.exists():
            return candidate

    return None

get_joint_grasps

get_joint_grasps(env: CPUMujocoEnv, obj: MlSpacesArticulationObject, joint_idx: int, include_flipped: bool = True, grasp_libraries: list[str] | None = None) -> tuple[ndarray, ndarray]

Load the first available joint grasps for a given object and joint in the world frame.

Parameters:

Name Type Description Default
env CPUMujocoEnv

The environment

required
obj MlSpacesArticulationObject

The object

required
joint_idx int

The index of the joint

required
include_flipped bool

Whether to include flipped grasps

True
grasp_libraries list[str] | None

The grasp libraries to use (defaults to all available libraries for the object)

None

Returns:

Type Description
ndarray

Numpy array of shape (N, 4, 4) containing the grasp poses in the world frame.

ndarray

Numpy array of shape (4, 4) containing the joint body pose in the world frame.

Source code in molmo_spaces/utils/grasps.py
def get_joint_grasps(
    env: CPUMujocoEnv,
    obj: MlSpacesArticulationObject,
    joint_idx: int,
    include_flipped: bool = True,
    grasp_libraries: list[str] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Load joint grasps for one joint of an articulated object, expressed in the world frame.

    The joint's scene name is mapped to its asset name through the scene
    metadata. The first available grasps are loaded in the joint body's local
    frame and transformed by that body's current world pose.

    Args:
        env: The environment
        obj: The object
        joint_idx: The index of the joint
        include_flipped: Whether to also append each grasp rotated 180 degrees about its z-axis
        grasp_libraries: The grasp libraries to use (defaults to all available libraries for the object)

    Returns:
        Numpy array of shape (N, 4, 4) containing the grasp poses in the world frame.
        Numpy array of shape (4, 4) containing the joint body pose in the world frame.

    Raises:
        ValueError: If scene metadata is missing, the object is not in it, or no grasps are found.
    """
    metadata = env.current_scene_metadata
    if metadata is None:
        raise ValueError(f"Could not load grasps for object {obj.name}: No scene metadata found!")
    if obj.name not in metadata["objects"]:
        raise ValueError(
            f"Could not load grasps for object {obj.name}: Object not found in scene metadata!"
        )

    joint_name: str = obj.joint_names[joint_idx]
    object_entry = metadata["objects"][obj.name]
    asset_joint_name = object_entry["name_map"]["joints"][joint_name]
    asset_id = object_entry["asset_id"]

    local_grasps = load_joint_grasps(asset_id, asset_joint_name, grasp_libraries, num_grasps=int(1e6))
    if len(local_grasps) == 0:
        raise ValueError(f"No grasps found for {obj.name}/{joint_name}")

    # World pose of the body the joint is attached to.
    body_id = env.current_model.joint(joint_name).bodyid.item()
    joint_body_pose = np.eye(4)
    joint_body_pose[:3, :3] = env.current_data.xmat[body_id].reshape(3, 3)
    joint_body_pose[:3, 3] = env.current_data.xpos[body_id]

    grasps_world = joint_body_pose @ local_grasps
    all_grasp_poses = (
        np.concatenate([grasps_world, flip_grasps(grasps_world)])
        if include_flipped
        else grasps_world
    )
    suffix = " (including flipped versions)" if include_flipped else ""
    log.info(f"Loaded {len(all_grasp_poses)} total grasp poses" + suffix)
    return all_grasp_poses, joint_body_pose

get_pickup_grasp_path

get_pickup_grasp_path(uid: str, grasp_libraries: Sequence[str] | None = None) -> Path | None
Source code in molmo_spaces/utils/grasps.py
def get_pickup_grasp_path(uid: str, grasp_libraries: Sequence[str] | None = None) -> Path | None:
    """Return the first existing pickup-grasp file for ``uid``, or ``None``.

    Candidate libraries come from ``_filter_grasp_libraries_for_object`` and are
    tried in the order it returns them. A library registered in
    ``USER_GRASP_LIBRARIES`` is resolved through its grasp index, keyed by the
    part of the library name after the first ``/``. Any other library is
    looked up at ``ASSETS_DIR/grasps/<library>/<uid>/<uid>_grasps_filtered.npz``.

    Args:
        uid: Asset ID of the object.
        grasp_libraries: Optional restriction on which libraries to search.

    Returns:
        Path of the first candidate file that exists on disk, else ``None``.
    """
    for library in _filter_grasp_libraries_for_object(uid, grasp_libraries):
        if library not in USER_GRASP_LIBRARIES:
            candidate = ASSETS_DIR / f"grasps/{library}/{uid}/{uid}_grasps_filtered.npz"
        else:
            library_dir = USER_GRASP_LIBRARIES[library]
            index = get_user_grasp_library_index(library_dir)
            robot_name = library.split("/", 1)[-1]
            relative = index.grasp_paths.get(robot_name, {}).get(uid, None)
            candidate = None if relative is None else library_dir / relative

        if candidate is not None and candidate.exists():
            return candidate

    return None

get_pickup_grasps

get_pickup_grasps(env: CPUMujocoEnv, obj: MlSpacesObject, include_flipped: bool = True, grasp_libraries: list[str] | None = None) -> ndarray

Load the first available pickup grasps for a given object in the world frame.

Parameters:

Name Type Description Default
env CPUMujocoEnv

The environment

required
obj MlSpacesObject

The object

required
include_flipped bool

Whether to include flipped grasps

True
grasp_libraries list[str] | None

The grasp libraries to use (defaults to all available libraries for the object)

None

Returns:

Type Description
ndarray

A numpy array of shape (N, 4, 4) containing the grasp poses in the world frame.

Source code in molmo_spaces/utils/grasps.py
def get_pickup_grasps(
    env: CPUMujocoEnv,
    obj: MlSpacesObject,
    include_flipped: bool = True,
    grasp_libraries: list[str] | None = None,
) -> np.ndarray:
    """
    Load pickup grasps for an object and express them in the world frame.

    The object's asset ID is read from ``env.current_scene_metadata``. Up to
    1e6 grasps are then loaded with ``load_pickup_grasps`` and left-multiplied
    by ``obj.pose``.

    Args:
        env: The environment (provides ``current_scene_metadata``)
        obj: The object
        include_flipped: If True, append ``flip_grasps`` of the world-frame grasps after the originals
        grasp_libraries: The grasp libraries to use (defaults to all available libraries for the object)

    Returns:
        A numpy array of stacked 4x4 grasp poses in the world frame. When
        ``include_flipped`` is True, the flipped copies follow the originals.

    Raises:
        ValueError: if there is no scene metadata, the object is missing from it,
            or no grasps were loaded.
    """
    metadata = env.current_scene_metadata
    if metadata is None:
        raise ValueError(f"Could not load grasps for object {obj.name}: No scene metadata found!")
    objects_meta = metadata["objects"]
    if obj.name not in objects_meta:
        raise ValueError(
            f"Could not load grasps for object {obj.name}: Object not found in scene metadata!"
        )

    asset_id: str = objects_meta[obj.name]["asset_id"]
    local_grasps = load_pickup_grasps(asset_id, grasp_libraries, num_grasps=int(1e6))
    if len(local_grasps) == 0:
        raise ValueError(f"No grasps found for {obj.name}")

    world_grasps = obj.pose @ local_grasps
    poses = (
        np.concatenate([world_grasps, flip_grasps(world_grasps)])
        if include_flipped
        else world_grasps
    )

    suffix = " (including flipped versions)" if include_flipped else ""
    log.info(f"Loaded {len(poses)} total grasp poses" + suffix)
    return poses

has_joint_grasp_path

has_joint_grasp_path(uid: str, joint_name: str, grasp_libraries: Sequence[str] | None = None) -> bool
Source code in molmo_spaces/utils/grasps.py
@sanitize_grasp_library_list_and_cache(cache_size=10000)
def has_joint_grasp_path(
    uid: str, joint_name: str, grasp_libraries: Sequence[str] | None = None
) -> bool:
    """Report whether a grasp file exists for joint ``joint_name`` of asset ``uid``.

    The decorator memoizes results and turns ``grasp_libraries`` into a tuple
    so it can serve as a cache key. A file that appears after the first call
    with the same arguments is therefore not seen until the cache is cleared.
    """
    found = get_joint_grasp_path(uid, joint_name, grasp_libraries)
    return found is not None

has_pickup_grasp_path

has_pickup_grasp_path(uid: str, grasp_libraries: Sequence[str] | None = None) -> bool
Source code in molmo_spaces/utils/grasps.py
@sanitize_grasp_library_list_and_cache(cache_size=10000)
def has_pickup_grasp_path(uid: str, grasp_libraries: Sequence[str] | None = None) -> bool:
    """Report whether a pickup-grasp file exists for asset ``uid``.

    The decorator memoizes results and turns ``grasp_libraries`` into a tuple
    so it can serve as a cache key. A file that appears after the first call
    with the same arguments is therefore not seen until the cache is cleared.
    """
    found = get_pickup_grasp_path(uid, grasp_libraries)
    return found is not None

has_valid_joint_grasps

has_valid_joint_grasps(uid: str, joint_name: str, num_grasps: int = 1, grasp_libraries: Sequence[str] | None = None) -> bool
Source code in molmo_spaces/utils/grasps.py
@sanitize_grasp_library_list_and_cache(cache_size=10000)
def has_valid_joint_grasps(
    uid: str,
    joint_name: str,
    num_grasps: int = 1,
    grasp_libraries: Sequence[str] | None = None,
) -> bool:
    """Check that the joint grasp file for ``uid``/``joint_name`` holds enough grasps.

    Only the ``.npy`` header of the ``transforms.npy`` member inside the
    ``.npz`` archive is parsed, so the grasp array itself is never loaded.

    Returns:
        False when no grasp file is found. Otherwise, whether the leading
        dimension of ``transforms`` is at least ``num_grasps``.
    """
    path = get_joint_grasp_path(uid, joint_name, grasp_libraries)
    if path is None:
        return False

    with zipfile.ZipFile(path) as archive, archive.open("transforms.npy") as member:
        major, _minor = np.lib.format.read_magic(member)
        # Format 1.x headers use the 1.0 reader; everything else uses the 2.0 reader.
        read_header = (
            np.lib.format.read_array_header_1_0
            if major == 1
            else np.lib.format.read_array_header_2_0
        )
        shape, _fortran_order, _dtype = read_header(member)
    return shape[0] >= num_grasps

has_valid_pickup_grasps

has_valid_pickup_grasps(uid: str, num_grasps: int = 1, grasp_libraries: Sequence[str] | None = None) -> bool
Source code in molmo_spaces/utils/grasps.py
@sanitize_grasp_library_list_and_cache(cache_size=10000)
def has_valid_pickup_grasps(
    uid: str, num_grasps: int = 1, grasp_libraries: Sequence[str] | None = None
) -> bool:
    """Check that the pickup-grasp file for ``uid`` holds enough grasps.

    Only the ``.npy`` header of the ``transforms.npy`` member inside the
    ``.npz`` archive is parsed, so the grasp array itself is never loaded.

    Returns:
        False when no grasp file is found. Otherwise, whether the leading
        dimension of ``transforms`` is at least ``num_grasps``.
    """
    path = get_pickup_grasp_path(uid, grasp_libraries)
    if path is None:
        return False

    with zipfile.ZipFile(path) as archive, archive.open("transforms.npy") as member:
        major, _minor = np.lib.format.read_magic(member)
        # Format 1.x headers use the 1.0 reader; everything else uses the 2.0 reader.
        read_header = (
            np.lib.format.read_array_header_1_0
            if major == 1
            else np.lib.format.read_array_header_2_0
        )
        shape, _fortran_order, _dtype = read_header(member)
    return shape[0] >= num_grasps

load_joint_grasps

load_joint_grasps(uid: str, joint_name: str, grasp_libraries: list[str] | None = None, num_grasps: int = 50) -> ndarray

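Load up to num_grasps joint grasps for a given object and joint, in the joint's local frame, from the first available grasp file (randomly sampled without replacement when the file holds more than num_grasps).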

Parameters:

Name Type Description Default
uid str

The asset ID of the object

required
joint_name str

The name of the joint

required
grasp_libraries list[str] | None

The grasp libraries to use (defaults to all available libraries for the object)

None
num_grasps int

The maximum number of grasps to load

50

Returns:

Type Description
ndarray

A numpy array of shape (N, 4, 4) containing the grasp poses in the joint's local frame.

Source code in molmo_spaces/utils/grasps.py
def load_joint_grasps(
    uid: str, joint_name: str, grasp_libraries: list[str] | None = None, num_grasps: int = 50
) -> np.ndarray:
    """
    Load up to ``num_grasps`` joint grasps for an object joint, in the joint's local frame.

    The grasp file is the one found by ``get_joint_grasp_path`` (the first
    available library). If it holds more than ``num_grasps`` grasps, a random
    subset of that size is drawn without replacement with ``random.sample``.
    Otherwise all grasps are returned in file order.

    Args:
        uid: The asset ID of the object
        joint_name: The name of the joint
        grasp_libraries: The grasp libraries to use (defaults to all available libraries for the object)
        num_grasps: The maximum number of grasps to load

    Returns:
        A numpy array of shape (N, 4, 4) with N <= num_grasps, containing the grasp
        poses in the joint's local frame.

    Raises:
        ValueError: if no joint grasp file is found.
    """
    grasp_path = get_joint_grasp_path(uid, joint_name, grasp_libraries)
    if grasp_path is None:
        raise ValueError(f"No joint grasp file found for {uid}/{joint_name}")

    # np.load on an .npz returns an NpzFile that keeps the archive open. Use it
    # as a context manager so the handle is released once the array is in memory.
    with np.load(grasp_path) as npz_data:
        transforms: np.ndarray = npz_data["transforms"]

    if len(transforms) <= num_grasps:
        return transforms
    idxs = random.sample(range(len(transforms)), num_grasps)
    return transforms[idxs]

load_pickup_grasps

load_pickup_grasps(uid: str, grasp_libraries: list[str] | None = None, num_grasps: int = 50) -> ndarray

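Load up to num_grasps pickup grasps for a given object, in the object's local frame, from the first available grasp file (randomly sampled without replacement when the file holds more than num_grasps).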

Parameters:

Name Type Description Default
uid str

The asset ID of the object

required
grasp_libraries list[str] | None

The grasp libraries to use (defaults to all available libraries for the object)

None
num_grasps int

The maximum number of grasps to load

50

Returns:

Type Description
ndarray

A numpy array of shape (N, 4, 4) containing the grasp poses in the local frame

Source code in molmo_spaces/utils/grasps.py
def load_pickup_grasps(
    uid: str, grasp_libraries: list[str] | None = None, num_grasps: int = 50
) -> np.ndarray:
    """
    Load up to ``num_grasps`` pickup grasps for an object, in the object's local frame.

    The grasp file is the one found by ``get_pickup_grasp_path`` (the first
    available library). If it holds more than ``num_grasps`` grasps, a random
    subset of that size is drawn without replacement with ``random.sample``.
    Otherwise all grasps are returned in file order.

    Args:
        uid: The asset ID of the object
        grasp_libraries: The grasp libraries to use (defaults to all available libraries for the object)
        num_grasps: The maximum number of grasps to load

    Returns:
        A numpy array of shape (N, 4, 4) with N <= num_grasps, containing the grasp
        poses in the local frame.

    Raises:
        ValueError: if no grasp file is found.
    """
    grasp_path = get_pickup_grasp_path(uid, grasp_libraries)
    if grasp_path is None:
        raise ValueError(f"No grasp file found for {uid}")

    # np.load on an .npz returns an NpzFile that keeps the archive open. Use it
    # as a context manager so the handle is released once the array is in memory.
    with np.load(grasp_path) as npz_data:
        transforms: np.ndarray = npz_data["transforms"]

    if len(transforms) <= num_grasps:
        return transforms
    idxs = random.sample(range(len(transforms)), num_grasps)
    return transforms[idxs]

sanitize_grasp_library_list_and_cache

sanitize_grasp_library_list_and_cache(cache_size: int)
Source code in molmo_spaces/utils/grasps.py
def sanitize_grasp_library_list_and_cache(cache_size: int):
    """Build a decorator that memoizes a grasp-lookup function with an LRU cache.

    Each call is bound against the wrapped function's signature with defaults
    applied, so positional and keyword spellings of the same call share one
    cache entry. A ``grasp_libraries`` argument that is neither ``None`` nor
    already a tuple is converted to a tuple. This lets callers pass lists,
    which are unhashable, and still hit the cache.

    Args:
        cache_size: ``maxsize`` for the underlying ``functools.lru_cache``.

    Returns:
        A decorator. The decorated function also exposes ``cache_info`` and
        ``cache_clear`` from the underlying cache.
    """

    def decorator(func):
        signature = inspect.signature(func)
        memoized = lru_cache(maxsize=cache_size)(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            call = signature.bind(*args, **kwargs)
            call.apply_defaults()
            libraries = call.arguments.get("grasp_libraries")
            if not (libraries is None or isinstance(libraries, tuple)):
                call.arguments["grasp_libraries"] = tuple(libraries)
            return memoized(*call.args, **call.kwargs)

        wrapper.cache_info = memoized.cache_info  # type: ignore[attr-defined]
        wrapper.cache_clear = memoized.cache_clear  # type: ignore[attr-defined]
        return wrapper

    return decorator

lazy_loading_utils

Classes:

Name Description
UserAssetLibraryIndexEntry
UserGraspLibraryIndex

Functions:

Name Description
add_install_prefixes
debug_lazy_search
find_object_paths
get_thor_uid_to_xmls
get_user_grasp_library_index
get_user_library_index
install_grasps_for_scene
install_objects_for_scene
install_scene_from_path
install_scene_from_source_index
install_scene_with_objects_and_grasps_from_path
install_uid
locate_uid_package

Locate the package containing the given object UID.

Attributes:

Name Type Description
UserAssetLibraryIndex

UserAssetLibraryIndex module-attribute

UserAssetLibraryIndex = TypeAdapter(dict[str, UserAssetLibraryIndexEntry])

UserAssetLibraryIndexEntry

Bases: BaseModel

Attributes:

Name Type Description
metadata_npz_path Path | None
metadata_path Path
object_path Path
uid str
metadata_npz_path instance-attribute
metadata_npz_path: Path | None
metadata_path instance-attribute
metadata_path: Path
object_path instance-attribute
object_path: Path
uid instance-attribute
uid: str

UserGraspLibraryIndex

Bases: BaseModel

Attributes:

Name Type Description
articulated_grasp_paths dict[str, dict[str, dict[str, Path]]]
grasp_paths dict[str, dict[str, Path]]
articulated_grasp_paths class-attribute instance-attribute
articulated_grasp_paths: dict[str, dict[str, dict[str, Path]]] = Field(default_factory=dict)
grasp_paths class-attribute instance-attribute
grasp_paths: dict[str, dict[str, Path]] = Field(default_factory=dict)

add_install_prefixes

add_install_prefixes(data_type, source, relative_path)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def add_install_prefixes(data_type, source, relative_path):
    """Return the install location ``ASSETS_DIR / data_type / source / relative_path``."""
    return ASSETS_DIR.joinpath(data_type, source, relative_path)

debug_lazy_search

debug_lazy_search()
Source code in molmo_spaces/utils/lazy_loading_utils.py
def debug_lazy_search():
    """Manual smoke test: locate and install two hard-coded object UIDs.

    For each UID, prints the UID together with the ``locate_uid_package``
    result, then calls ``install_uid`` on it. Prints ``DONE`` at the end.
    """
    for uid in ("0000c32fde7f45efb8d14e8ba737d50c", "Bowl_1"):
        source, package, xml_path = locate_uid_package(uid)
        print(uid, source, package, xml_path)
        install_uid(uid)

    print("DONE")

find_object_paths

find_object_paths(xml_path, exclude_thor=True)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def find_object_paths(xml_path, exclude_thor=True):
    """Yield ``(source, relative_path)`` for object assets referenced by a scene XML.

    Scans ``<asset>`` children of type mesh, texture, material, hfield and skin,
    in that order. Only ``file`` attributes starting with ``../`` are considered.
    When ``exclude_thor`` is true, paths containing ``/objects/thor/`` are skipped.

    For each remaining path:
    - it is resolved against the scene directory;
    - its source is the first path component below the resource manager's
      ``cache_dir / "objects"``;
    - the yielded path is relative to ``source_dir("objects", source)``.

    Args:
        xml_path: Path to the scene XML file.
        exclude_thor: Skip assets whose path contains ``/objects/thor/``.

    Yields:
        ``(source, rel_asset)`` tuples.
    """
    scene_root = ET.parse(xml_path).getroot()
    scene_dir = Path(xml_path).parent

    for asset_type in ("mesh", "texture", "material", "hfield", "skin"):
        for element in scene_root.findall(f".//asset/{asset_type}"):
            # Most assets carry their file path in the "file" attribute.
            file_path = element.attrib.get("file")
            if not file_path or not file_path.startswith("../"):
                continue
            if exclude_thor and "/objects/thor/" in file_path:
                continue

            # Object files are linked globally from the resource cache.
            resolved = (scene_dir / file_path).resolve()
            objects_cache = get_resource_manager().cache_dir / "objects"
            source = resolved.relative_to(objects_cache).parts[0]
            rel_asset = resolved.relative_to(
                get_resource_manager().source_dir("objects", source)
            )
            yield source, rel_asset

get_thor_uid_to_xmls cached

get_thor_uid_to_xmls() -> dict[str, Path]
Source code in molmo_spaces/utils/lazy_loading_utils.py
@cache
def get_thor_uid_to_xmls() -> dict[str, Path]:
    """Map each THOR object UID (XML file stem) to its XML path.

    Recursively collects ``*.xml`` under the resolved
    ``ASSETS_DIR / "objects" / "thor"`` directory. The result is computed once
    and cached. If two files share a stem, the one ``rglob`` yields later wins.
    """
    thor_root = (ASSETS_DIR / "objects" / "thor").resolve()
    uid_to_xml = {}
    for xml_file in thor_root.rglob("*.xml"):
        uid_to_xml[xml_file.stem] = xml_file
    return uid_to_xml

get_user_grasp_library_index cached

get_user_grasp_library_index(user_library_path: Path)
Source code in molmo_spaces/utils/lazy_loading_utils.py
@cache
def get_user_grasp_library_index(user_library_path: Path):
    """Load and validate ``grasps_index.json`` from a user grasp library.

    Results are cached per ``user_library_path``. The file is read as bytes so
    pydantic decodes the JSON as UTF-8, independent of the platform's locale
    encoding (a text-mode ``open`` without ``encoding=`` would use the locale).

    Args:
        user_library_path: Root directory of the user grasp library.

    Returns:
        The validated ``UserGraspLibraryIndex``.
    """
    index_file = user_library_path / "grasps_index.json"
    return UserGraspLibraryIndex.model_validate_json(index_file.read_bytes())

get_user_library_index cached

get_user_library_index(user_library_path: Path)
Source code in molmo_spaces/utils/lazy_loading_utils.py
@cache
def get_user_library_index(user_library_path: Path):
    """Load and validate ``assets_index.json`` from a user asset library.

    Results are cached per ``user_library_path``. The file is read as bytes so
    pydantic decodes the JSON as UTF-8, independent of the platform's locale
    encoding (a text-mode ``open`` without ``encoding=`` would use the locale).

    Args:
        user_library_path: Root directory of the user asset library.

    Returns:
        A ``dict[str, UserAssetLibraryIndexEntry]`` mapping UID to its index entry.
    """
    index_file = user_library_path / "assets_index.json"
    return UserAssetLibraryIndex.validate_json(index_file.read_bytes())

install_grasps_for_scene

install_grasps_for_scene(xml_path, grasp_source='droid_objaverse', exclude_thor=True)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_grasps_for_scene(xml_path, grasp_source="droid_objaverse", exclude_thor=True):
    """Install the grasp archives for objects referenced by a scene XML.

    For every asset yielded by ``find_object_paths``, the asset's file name is
    split with ``split_query_tokens``. Each token is looked up in the
    ``grasp_source`` grasp index, and all matching archives are installed in
    one call.

    Returns ``{}`` without installing anything when either:
    - ``grasp_source`` is ``"droid"`` or ``"rum"`` (THOR-only sources) and
      ``exclude_thor`` is true; or
    - ``grasp_source`` is not registered under
      ``DATA_TYPE_TO_SOURCE_TO_VERSION["grasps"]``.

    Otherwise returns ``{grasp_source: <set of archives>}``.
    """
    thor_only_source = grasp_source in ("droid", "rum")
    if thor_only_source and exclude_thor:
        return {}
    if grasp_source not in DATA_TYPE_TO_SOURCE_TO_VERSION["grasps"]:
        return {}

    archives = set()
    for _, asset_rel_path in find_object_paths(xml_path, exclude_thor=exclude_thor):
        for token in split_query_tokens(asset_rel_path.name):
            archives.update(get_resource_manager().index_lookup("grasps", grasp_source, token))

    source_to_archives = {grasp_source: archives}
    get_resource_manager().install_packages("grasps", source_to_archives)
    return source_to_archives

install_objects_for_scene

install_objects_for_scene(xml_path, exclude_thor=True)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_objects_for_scene(xml_path, exclude_thor=True):
    """Install the object archives needed by assets referenced in a scene XML.

    Returns ``{}`` without installing anything when no ``"objaverse"`` object
    source is registered. Otherwise:
    - groups the archives ``find_archives`` returns for each asset by object source;
    - deduplicates them in first-seen order;
    - installs them and returns that ``{source: [archives]}`` mapping.

    Args:
        xml_path: Path to the scene XML file.
        exclude_thor: Passed to ``find_object_paths``.
    """
    if "objaverse" not in DATA_TYPE_TO_SOURCE_TO_VERSION["objects"]:
        return {}

    # Accumulate into our own lists instead of extending the list returned by
    # find_archives, which may be shared with the resource manager (e.g. cached).
    collected = {}
    for source, rel_asset in find_object_paths(xml_path, exclude_thor=exclude_thor):
        archives = get_resource_manager().find_archives("objects", source, [rel_asset])
        collected.setdefault(source, []).extend(archives)

    # Deduplicate while keeping first-seen order, so the install order and the
    # returned lists are deterministic (list(set(...)) order is not).
    source_to_archives = {
        source: list(dict.fromkeys(archives)) for source, archives in collected.items()
    }

    get_resource_manager().install_packages("objects", source_to_archives)

    return source_to_archives

install_scene_from_path

install_scene_from_path(xml_path)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_scene_from_path(xml_path):
    """Install the scene archive that contains ``xml_path``.

    ``xml_path`` must lie under ``get_scenes_root()``. Its first path component
    below that root names the scene source, and the remainder is passed to
    ``find_archives``.

    Returns:
        ``{scene_source: archives}`` as handed to ``install_packages``.

    Raises:
        RuntimeError: if no archive contains the scene file.
    """
    scene_file = Path(xml_path)
    scene_source = scene_file.relative_to(get_scenes_root()).parts[0]
    rel_path = scene_file.relative_to(get_scenes_root() / scene_source)

    archives = get_resource_manager().find_archives("scenes", scene_source, [rel_path])
    if not archives:
        raise RuntimeError(
            f"BUG: could not find archive for {xml_path} (relative {rel_path} for {scene_source})"
        )

    source_to_paths = {scene_source: archives}
    get_resource_manager().install_packages("scenes", source_to_paths)
    return source_to_paths

install_scene_from_source_index

install_scene_from_source_index(source, idx)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_scene_from_source_index(source, idx):
    """Install the single scene archive indexed under ``str(idx)`` for ``source``.

    Args:
        source: Scene data source name.
        idx: Scene index. It is stringified for the index lookup.

    Returns:
        ``{source: archives}`` as handed to ``install_packages``.

    Raises:
        ValueError: if the lookup does not return exactly one archive.
    """
    archives = get_resource_manager().index_lookup("scenes", source, str(idx))
    # Use an explicit exception instead of ``assert`` so the check still runs
    # under ``python -O``, and zero or several matches fail the same way.
    if len(archives) != 1:
        raise ValueError(f"{source=} {idx=} returned {len(archives)} archives (expected 1)")
    source_to_paths = {source: archives}
    get_resource_manager().install_packages("scenes", source_to_paths)
    return source_to_paths

install_scene_with_objects_and_grasps_from_path

install_scene_with_objects_and_grasps_from_path(xml_path, grasp_sources=('droid_objaverse',), exclude_thor=True)
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_scene_with_objects_and_grasps_from_path(
    xml_path, grasp_sources=("droid_objaverse",), exclude_thor=True
):
    """Install a scene and, when the resource cache is locked, its objects and grasps.

    The scene archive is always installed. If the resource manager's
    ``cache_lock`` is falsy, the function returns right after that step.

    Returns:
        A mapping from data type to the ``{source: archives}`` dict returned by
        the matching installer:
        - ``"scenes"`` is always present;
        - ``"objects"`` and ``"grasps"`` are present only when ``cache_lock`` is
          truthy. The grasp results for all ``grasp_sources`` are merged into
          one dict.
    """
    installed = {"scenes": install_scene_from_path(xml_path)}

    if not get_resource_manager().cache_lock:
        # Objects and grasps are pre-cached and reached through a global
        # symlink in this mode, so linking the scene is all that is needed.
        return installed

    installed["objects"] = install_objects_for_scene(xml_path, exclude_thor=exclude_thor)

    merged_grasps = {}
    installed["grasps"] = merged_grasps
    for source in grasp_sources:
        merged_grasps.update(
            install_grasps_for_scene(xml_path, grasp_source=source, exclude_thor=exclude_thor)
        )

    return installed

install_uid

install_uid(uid, grasp_source='droid_objaverse')
Source code in molmo_spaces/utils/lazy_loading_utils.py
def install_uid(uid, grasp_source="droid_objaverse"):
    """Make object ``uid`` available locally and return its XML path.

    Behavior depends on the source reported by ``locate_uid_package``:
    - ``"thor"``, or a key of ``USER_ASSET_LIBRARIES``: the path is returned
      without installing anything.
    - Any other source: the object package is installed. Then grasp archives
      from ``grasp_source`` that match tokens of the XML file's stem are
      installed.

    Raises:
        ValueError: if the UID is not found in any object source.
    """
    source, package, xml_path = locate_uid_package(uid)

    if source is None:
        known_sources = sorted(DATA_TYPE_TO_SOURCE_TO_VERSION['objects'].keys())
        raise ValueError(f"{uid} not found in object sources {known_sources}")

    if source in USER_ASSET_LIBRARIES or source == "thor":
        return xml_path

    get_resource_manager().install_packages("objects", {source: [package]})

    # Install grasps (on-demand for objaverse)
    grasp_archives = set()
    for token in split_query_tokens(xml_path.stem):
        grasp_archives.update(get_resource_manager().index_lookup("grasps", grasp_source, token))
    get_resource_manager().install_packages("grasps", {grasp_source: grasp_archives})

    return xml_path

locate_uid_package

locate_uid_package(uid: str, extension: str = 'xml') -> tuple[str, str | None, Path] | tuple[None, None, None]

Locate the package containing the given object UID.

Parameters:

Name Type Description Default
uid str

The UID of the object to locate.

required
extension str

The extension of the file to locate.

'xml'

Returns:

Type Description
tuple[str, str | None, Path] | tuple[None, None, None]

A tuple containing the source, package, and XML path of the object. If the object is not found, returns (None, None, None).

Source code in molmo_spaces/utils/lazy_loading_utils.py
def locate_uid_package(
    uid: str,
    extension: str = "xml",
) -> tuple[str, str | None, Path] | tuple[None, None, None]:
    """
    Locate the package containing the given object UID.

    Lookup order:
      1. THOR objects, found by scanning the installed THOR directory
         (``get_thor_uid_to_xmls``). The package is ``None``.
      2. Every other registered object source, sorted by name, via the resource
         manager's index and archive tries. The first leaf path whose file name
         is exactly ``f"{uid}.{extension}"`` wins.
      3. User asset libraries, via their ``assets_index.json``. The package is
         ``None``, and ``extension`` is not consulted at this step.

    Args:
        uid: The UID of the object to locate.
        extension: The extension of the file to locate.

    Returns:
        A tuple containing the source, package, and XML path of the object.
            If the object is not found, returns (None, None, None).
    """

    # Since thor objects are always fully installed, we just search in the file system
    thor_uid_to_xmls = get_thor_uid_to_xmls()

    if uid in thor_uid_to_xmls:
        base = (ASSETS_DIR / "objects" / "thor").resolve()
        xml_path = thor_uid_to_xmls[uid]
        return "thor", None, add_install_prefixes("objects", "thor", xml_path.relative_to(base))

    file_name = f"{uid}.{extension}"
    # Match the full file name, not just a suffix: a bare endswith(file_name)
    # would also accept e.g. "SmallBowl_1.xml" when looking for "Bowl_1".
    path_suffixes = ("/" + file_name, "\\" + file_name)
    # The query tokens depend only on the uid, so compute them once for all sources.
    substrings = tuple(split_query_tokens(uid))

    # For other sources (aka Objaverse for now), we need to search through
    # the data tries from the resource manager
    for object_source in sorted(DATA_TYPE_TO_SOURCE_TO_VERSION["objects"].keys()):
        if object_source == "thor":
            continue

        for substring in substrings:
            possible_archives = get_resource_manager().index_lookup(
                "objects", object_source, substring
            )
            if not possible_archives:
                continue

            # TODO pass archives to avoid full trie initialization? it's kind of fast, but...
            tries = get_resource_manager().tries("objects", object_source)
            for possible_archive in possible_archives:
                trie = tries.get(possible_archive)
                if trie is None:
                    # The previous `{}` fallback has no leaf_paths() and would crash.
                    continue
                for path in trie.leaf_paths():
                    if path == file_name or path.endswith(path_suffixes):
                        return (
                            object_source,
                            possible_archive,
                            add_install_prefixes("objects", object_source, path),
                        )

    for user_library_name, user_library_dir in USER_ASSET_LIBRARIES.items():
        user_library_index = get_user_library_index(user_library_dir)
        if uid in user_library_index:
            # Assume the user is getting the object
            return user_library_name, None, user_library_dir / user_library_index[uid].object_path

    return None, None, None

lemma_utils

Functions:

Name Description
best_lemma_via_specificity
is_physical_entity
normalize_expression
simple_lemma

Attributes:

Name Type Description
PHYSICAL_ENTITY_SYNSET

PHYSICAL_ENTITY_SYNSET module-attribute

PHYSICAL_ENTITY_SYNSET = synset('physical_entity.n.01')

best_lemma_via_specificity cached

best_lemma_via_specificity(synset_str: str, enforce_physical_entity: bool = True) -> str
Source code in molmo_spaces/utils/lemma_utils.py
@functools.lru_cache(maxsize=1000)
def best_lemma_via_specificity(synset_str: str, enforce_physical_entity: bool = True) -> str:
    """Pick the least ambiguous lemma name of a WordNet synset.

    For each lemma name of ``synset_str``, in WordNet order, count the noun
    synsets that name maps to. The count includes only physical-entity synsets
    when the synset itself is a physical entity or ``enforce_physical_entity``
    is set. The lemma with the smallest non-zero count wins, and ties keep the
    earlier lemma.

    If enforcement leaves no candidate, the function retries once with
    ``enforce_physical_entity=False``. That retry still filters when the synset
    itself is a physical entity.

    Raises:
        AssertionError: if no lemma has a non-zero count.
    """
    synset = wn.synset(synset_str)
    physical_only = is_physical_entity(synset) or enforce_physical_entity

    best_lemma = None
    best_count = 100000
    for lemma_name in synset.lemma_names():
        noun_synsets = wn.synsets(lemma_name, pos=wn.NOUN)
        if physical_only:
            count = sum(1 for candidate in noun_synsets if is_physical_entity(candidate))
        else:
            count = len(noun_synsets)
        if 0 < count < best_count:
            best_count, best_lemma = count, lemma_name

    if best_lemma is None and enforce_physical_entity:
        best_lemma = best_lemma_via_specificity(synset_str, enforce_physical_entity=False)

    assert best_lemma is not None, f"Failed to find lemma for {synset_str}"

    return best_lemma

is_physical_entity

is_physical_entity(synset: Synset | str) -> bool
Source code in molmo_spaces/utils/lemma_utils.py
def is_physical_entity(synset: Synset | str) -> bool:
    """Return True if ``synset`` descends from (or is) ``physical_entity.n.01``.

    Accepts a ``Synset`` or its name string. The test checks whether
    ``PHYSICAL_ENTITY_SYNSET`` is among the lowest common hypernyms of the
    synset and ``PHYSICAL_ENTITY_SYNSET``.
    """
    resolved = wn.synset(synset) if isinstance(synset, str) else synset
    common = resolved.lowest_common_hypernyms(PHYSICAL_ENTITY_SYNSET)
    return PHYSICAL_ENTITY_SYNSET in common

normalize_expression

normalize_expression(text: str) -> str
Source code in molmo_spaces/utils/lemma_utils.py
def normalize_expression(text: str) -> str:
    """Normalize a referring expression or synset name for comparison.

    Steps:
    1. Synset-like strings (those containing ``".n."``) are replaced by their
       first lemma name via ``simple_lemma``.
    2. The text is lowercased and underscores become spaces.
    3. Surrounding whitespace and the punctuation characters period, semicolon,
       slash, comma, single/double quote and backslash are trimmed. Trimming
       repeats until nothing changes, so mixed runs such as ``"mug ."`` or
       ``"'cup', "`` are removed completely.

    Args:
        text: Expression or synset name.

    Returns:
        The normalized text.
    """
    if ".n." in text:
        text = simple_lemma(text)

    punctuation = ".;/,'\"\\"
    normalized = text.lower().replace("_", " ")
    # A single strip().strip(punctuation) stops at a space that sits behind
    # punctuation (e.g. "mug ." -> "mug "), so iterate to a fixed point.
    # Each pass shortens the string or ends the loop.
    while True:
        trimmed = normalized.strip().strip(punctuation)
        if trimmed == normalized:
            return normalized
        normalized = trimmed

simple_lemma

simple_lemma(synset_str: str) -> str
Source code in molmo_spaces/utils/lemma_utils.py
def simple_lemma(synset_str: str) -> str:
    """Return the first lemma name listed for the WordNet synset ``synset_str``."""
    lemma_names = wn.synset(synset_str).lemma_names()
    return lemma_names[0]

license_utils

Functions:

Name Description
grasp_targets
ithor_resolver
procthor_resolver
resolve_grasps_license
resolve_license
resolve_object_license
resolve_robot_license
resolve_scene_license
scene_includes
scene_path_resolve
validate_identifier
validate_objaverse_identifier
validate_thor_identifier

Attributes:

Name Type Description
ATTRIBUTION_TEMPLATE
DEFAULT_LICENSE
ROBOT_LICENSE

ATTRIBUTION_TEMPLATE module-attribute

ATTRIBUTION_TEMPLATE = '{assets}' + f' by the {DEFAULT_LICENSE['creator_name']}, licensed under {DEFAULT_LICENSE['license'].replace('-', ' ')}.'

DEFAULT_LICENSE module-attribute

DEFAULT_LICENSE = {'license': 'CC-BY-4.0', 'license_url': 'https://creativecommons.org/licenses/by/4.0/', 'creator_name': 'Allen Institute for AI (Ai2)', 'source': 'In-house'}

ROBOT_LICENSE module-attribute

ROBOT_LICENSE = {}

grasp_targets

grasp_targets(data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def grasp_targets(data_source, identifier):
    """List the objects covered by a grasp archive, grouped by origin, with attribution.

    Steps:
    1. Find the single archive of the ``grasps`` data source whose name
       contains ``identifier``, and install it.
    2. Derive a target name from each ``*_grasps_filtered.npz`` /
       ``*_grasps_filtered.json`` file in that archive.
    3. Group the targets using ``ObjectMeta.annotation``:
       - ``"objaverse"``: the annotation has a truthy ``isObjaverse``. The
         attribution comes from ``resolve_object_license("objaverse", target)``.
       - ``"thor"``: an annotation exists and ``isObjaverse`` is present and falsy.
       - ``"ithor builtin"``: everything else (no annotation, or an annotation
         without an ``isObjaverse`` key). The identifier gets the suffix
         ``" (ithor scene builtin)"``.

    Args:
        data_source: Grasp data source name.
        identifier: Substring used to select the archive. Also used to rebuild
            target names.

    Returns:
        Dict mapping group name to a list of ``{"identifier", "attribution"}``
        dicts sorted by target. Empty groups are omitted.

    Raises:
        ValueError: if no archive name contains ``identifier``.
        AssertionError: if more than one archive name contains ``identifier``.
    """
    info = get_resource_manager().source_info("grasps", data_source, recursive=True)
    archives = info["archive_to_relative_paths"].keys()
    # Substring match on archive names; exactly one hit is expected below.
    archive = [archive for archive in archives if identifier in archive]

    if len(archive) == 0:
        raise ValueError(f"No archives for `grasps` {data_source} {identifier}")

    assert len(archive) == 1, f"Error: multiple archives for `grasps` {data_source} {identifier}"

    get_resource_manager().install_packages_bulk("grasps", {data_source: archive})

    # Base file name up to the first "_grasps_"
    # (e.g. "<name>_grasps_filtered.npz" -> "<name>").
    targets = [
        str(path).split("/")[-1].split("_grasps_")[0]
        for path in info["archive_to_relative_paths"][archive[0]]
        if (
            str(path).endswith("_grasps_filtered.npz")
            or str(path).endswith("_grasps_filtered.json")
        )
    ]
    # Rebuild each name as "<identifier>_<token>":
    # - remove every occurrence of ``identifier`` and split on "_";
    # - keep the second piece, or the only piece if there is just one;
    # - strip surrounding "_", so a name equal to ``identifier`` maps back to ``identifier``.
    # NOTE(review): presumably grasp files are named "<identifier>_<part>_...";
    # confirm against the archive layout.
    targets = [
        (f"{identifier}_" + target.replace(identifier, "").split("_")[:2][-1]).strip("_")
        for target in targets
    ]
    # Missing annotation or missing key -> default False -> not objaverse.
    objaverse_targets = [
        target
        for target in targets
        if (ObjectMeta.annotation(target) or {}).get("isObjaverse", False)
    ]
    # Missing annotation or missing key -> default True -> not thor.
    thor_targets = [
        target
        for target in targets
        if not (ObjectMeta.annotation(target) or {}).get("isObjaverse", True)
    ]
    # Remainder: targets without an annotation or without an isObjaverse key.
    # (List membership makes this quadratic in the number of targets.)
    ithor_targets = [
        target
        for target in targets
        if target not in objaverse_targets and target not in thor_targets
    ]
    ret = {
        "objaverse": [
            {
                "identifier": target,
                "attribution": resolve_object_license("objaverse", target)["attribution"],
            }
            for target in sorted(objaverse_targets)
        ],
        "thor": [
            {
                "identifier": target,
                "attribution": ATTRIBUTION_TEMPLATE.format(assets="Model(s)"),
            }
            for target in sorted(thor_targets)
        ],
        "ithor builtin": [
            {
                "identifier": target + " (ithor scene builtin)",
                "attribution": ATTRIBUTION_TEMPLATE.format(assets="Asset(s)"),
            }
            for target in sorted(ithor_targets)
        ],
    }

    # Drop groups with no targets.
    return {key: value for key, value in ret.items() if value}

ithor_resolver

ithor_resolver(source: str, idx: int, scene_info: SourceInfo, modalities: list[Path], variant: str = '') -> Path
Source code in molmo_spaces/utils/license_utils.py
def ithor_resolver(
    source: str, idx: int, scene_info: SourceInfo, modalities: list[Path], variant: str = ""
) -> Path:
    """Resolve an iTHOR scene file: the first listed modality under ``scene_info["root_dir"]``.

    ``source``, ``idx`` and ``variant`` are accepted but unused. Presumably they
    exist so all resolvers share one signature (compare ``procthor_resolver``).
    """
    first_modality = modalities[0]
    return scene_info["root_dir"] / first_modality

procthor_resolver

procthor_resolver(source: str, idx: int, scene_info: SourceInfo, modalities: list[Path], variant: str = '_ceiling') -> Path
Source code in molmo_spaces/utils/license_utils.py
def procthor_resolver(
    source: str, idx: int, scene_info: SourceInfo, modalities: list[Path], variant: str = "_ceiling"
) -> Path:
    """Resolve a ProcTHOR scene file, preferring the ``variant`` XML when it exists.

    The split name is the part of ``source`` after its last ``-``. The file
    ``{split}_{idx}{variant}.xml`` is used if it appears in ``modalities``.
    Otherwise ``{split}_{idx}.xml`` is used, and it must be present.

    Raises:
        AssertionError: if neither candidate is listed in ``modalities``.
    """
    split = source.rsplit("-", 1)[-1]
    preferred = f"{split}_{idx}{variant}.xml"
    target = preferred if Path(preferred) in modalities else f"{split}_{idx}.xml"
    assert Path(target) in modalities, f"Missing {target=} with available modalities {modalities}"
    return scene_info["root_dir"] / target

resolve_grasps_license

resolve_grasps_license(data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def resolve_grasps_license(data_source, identifier):
    """Build the license record for a grasp archive.

    ``identifier`` is replaced by the value returned from
    ``validate_identifier("grasps", data_source, identifier)``. The record
    starts from ``DEFAULT_LICENSE`` and adds grasp-specific scope and
    attribution text. It also lists the objects the grasps apply to under
    ``external_target_assets``, via ``grasp_targets``, which also installs the
    archive.
    """
    identifier = validate_identifier("grasps", data_source, identifier)

    determination = (
        "Grasp data and metadata are derived or created independently of the assets to which they apply."
        f" {DEFAULT_LICENSE['license']} applies only to grasp data and metadata, not object meshes, textures, or other"
        " underlying assets."
    )

    return {
        "data_type": "grasps",
        "data_source": data_source,
        "asset_id": identifier,
        **DEFAULT_LICENSE,
        "attribution": ATTRIBUTION_TEMPLATE.format(assets="Grasps generated"),
        "scope": "Grasp poses, collision data, and metadata (not the underlying 3D model or object).",
        "relationship_to_assets": "annotation",
        "asset_licenses": "Underlying objects are independently licensed; see asset license info for details.",
        "license_determination": determination,
        "external_target_assets": grasp_targets(data_source, identifier),
    }

resolve_license

resolve_license(data_type, data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def resolve_license(data_type, data_source, identifier):
    if data_type == "objects":
        return resolve_object_license(data_source, identifier)
    if data_type == "scenes":
        return resolve_scene_license(data_source, identifier)
    if data_type == "grasps":
        return resolve_grasps_license(data_source, identifier)
    if data_type == "robots":
        return resolve_robot_license(data_source, identifier)

    raise ValueError(f"Non-valid {data_type=}")

resolve_object_license

resolve_object_license(data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def resolve_object_license(data_source, identifier):
    """Build the license record for an object asset.

    Supported sources:

    * ``"thor"``: ``DEFAULT_LICENSE`` for a validated THOR identifier.
    * ``"objathor_metadata"``: ``DEFAULT_LICENSE`` scoped to the annotation data
      only (not the underlying objects).
    * ``"objaverse"``: the per-asset Creative Commons license taken from the
      asset annotation's ``license_info``; the creator profile URL must be on
      Sketchfab (checked with an ``assert``).

    Raises:
        NotImplementedError: For an Objaverse license code other than ``by``,
            ``by-sa``, ``cc0``, ``by-nc`` or ``by-nc-sa``.
        ValueError: For any other ``data_source``.
    """
    if data_source == "thor":
        return {
            "data_type": "objects",
            "data_source": data_source,
            "asset_id_or_archive_name": validate_thor_identifier(identifier),
            **DEFAULT_LICENSE,
            "attribution": ATTRIBUTION_TEMPLATE.format(assets="Model(s)"),
        }

    if data_source == "objathor_metadata":
        assets = (
            "Object annotation (bounding boxes, masses, synsets, CLIP features, etc.) extracted"
        )

        return {
            "data_type": "objects",
            "data_source": "objathor_metadata",
            "asset_id": identifier,
            **DEFAULT_LICENSE,
            "attribution": ATTRIBUTION_TEMPLATE.format(assets=assets),
            "scope": "Annotation data derived from objects, including bounding boxes, physical properties,"
            " descriptions, CLIP embedding–based representations, and semantic labels.",
            "relationship_to_objects": "annotation",
            "object_licenses": "Underlying objects are independently licensed; see each object's license for details"
            " (under `objects` - `thor` and `objects` - `objaverse`).",
            "license_note": "This license applies only to the annotation data. Underlying objects are independently"
            " licensed and are not covered by this license.",
        }

    if data_source != "objaverse":
        raise ValueError(
            f"Can't determine license for `objects` with {data_source=} and {identifier=}"
        )

    anno = ObjectMeta.annotation(validate_objaverse_identifier(identifier))
    info = anno["license_info"]

    assert "sketchfab" in info["creator_profile_url"], (
        f"Only sketchfab assets expected, got {info['creator_profile_url']}"
    )

    record = {
        "data_type": "objects",
        "data_source": data_source,
        "asset_id": anno["assetId"],
        "creator_username": info["creator_username"],
        "creator_display_name": info["creator_display_name"],
        "creator_profile_url": info["creator_profile_url"],
        "source": "Sketchfab",
        "uri": info["uri"],
        "downloaded": "2021-2022",
        "license_determination": "License inferred from Sketchfab designation at time"
        " of download (circa 2021/2022).",
        "modifications": "The model has been significantly modified to reduce memory and processing requirements,"
        " including mesh decimation, convex collider extraction, and baking of visual effects via Blender scripts."
        " The provided quality may not reflect the original model.",
        "dataset_license": "This subset of Objaverse is licensed under ODC-BY 1.0.",
    }

    credit = f"Model by {record['creator_display_name']} ({record['creator_username']})"
    non_commercial_notice = (
        "This work is a derivative of the original asset and may not be used for commercial purposes."
    )

    # License-specific fields, in the key order each variant appends them.
    variants = {
        "by": {
            "license": "CC-BY-4.0",
            "license_url": "https://creativecommons.org/licenses/by/4.0/",
            "derivative_notice": "This work is a derivative of the original model.",
            "attribution": f"{credit}, licensed under CC BY 4.0.",
        },
        "by-sa": {
            "license": "CC-BY-SA-4.0",
            "license_url": "https://creativecommons.org/licenses/by-sa/4.0/",
            "derivative_license": "This derivative work is licensed under CC BY-SA 4.0.",
            "attribution": f"{credit}, licensed under CC BY-SA 4.0.",
        },
        "cc0": {
            "license": "CC0-1.0",
            "license_url": "https://creativecommons.org/publicdomain/zero/1.0/",
            "derivative_notice": "This work is a derivative of the original asset, which was released under CC0.",
            "attribution": f"{credit}, licensed under CC0-1.0.",
        },
        "by-nc": {
            "license": "CC-BY-NC-4.0",
            "license_url": "https://creativecommons.org/licenses/by-nc/4.0/",
            "commercial_use": False,
            "derivative_notice": non_commercial_notice,
            "attribution": f"{credit}, licensed under CC BY-NC 4.0. Non-commercial use only.",
        },
        "by-nc-sa": {
            "license": "CC-BY-NC-SA-4.0",
            "license_url": "https://creativecommons.org/licenses/by-nc-sa/4.0/",
            "commercial_use": False,
            "derivative_license": "This derivative work is licensed under CC BY-NC-SA 4.0.",
            "derivative_notice": non_commercial_notice,
            "attribution": f"{credit}, licensed under CC BY-NC-SA 4.0. Non-commercial use only.",
        },
    }

    terms = variants.get(info["license"])
    if terms is None:
        raise NotImplementedError(f"Got unsupported license {info['license']}")

    record.update(terms)
    return record

resolve_robot_license

resolve_robot_license(data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def resolve_robot_license(data_source, identifier):
    """Build the license record for a robot model.

    The robot is recognized by substring: the first keyword in the table below
    that occurs in ``identifier`` selects the license fields. Table order is
    therefore the matching priority.

    Args:
        data_source: Name of the robot source (copied into the record).
        identifier: Robot identifier, e.g. a name containing ``"franka"``.

    Returns:
        dict: ``data_type``/``data_source``/``asset_id`` followed by the
        robot-specific license fields.

    Raises:
        NotImplementedError: If no known keyword occurs in ``identifier``.
    """
    robot_licenses = (
        (
            "franka",
            {
                "creator_username": "Franka Robotics",
                "attribution": "Developed by Franka Robotics",
                "source": "mujoco_menagerie/franka_fr3",
                "license": "Apache 2.0",
                "uri": "https://github.com/google-deepmind/mujoco_menagerie/blob/main/franka_fr3/LICENSE",
                "downloaded": "2025",
                "modifications": "Changed position controller to force controller for hand gripper",
            },
        ),
        (
            "rby",
            {
                "creator_username": "Rainbow Robotics",
                "attribution": "Copyright 2024-2025 Rainbow Robotics",
                "source": "RainbowRobotics/rby1-sdk",
                "license": "Apache 2.0",
                "uri": "https://github.com/RainbowRobotics/rby1-sdk/blob/main/LICENSE",
                "downloaded": "2025",
                "modifications": "Added holonomic base and removed wheel controller and slider controllers",
            },
        ),
        (
            "robotiq",
            {
                "creator_username": "ROS-Industrial",
                "attribution": "Copyright (c) 2013, ROS-Industrial",
                "source": "mujoco_menagerie/robotiq_2f85_v4",
                "license": "BSD-2-Clause License",
                "uri": "https://github.com/google-deepmind/mujoco_menagerie/blob/main/robotiq_2f85_v4/LICENSE",
                "downloaded": "2025",
            },
        ),
        (
            "rum",
            {
                "creator_username": "NYU Generalizable Robotics and AI Lab (GRAIL)",
                "attribution": "Copyright (c) 2026 NYU Generalizable Robotics and AI Lab (GRAIL)",
                "source": "jeffacce/cap-policy",
                "license": "MIT",
                "uri": "https://github.com/jeffacce/cap-policy/blob/main/LICENSE",
                "downloaded": "2025",
            },
        ),
        (
            "yam",
            {
                "creator_username": "I2RT Robotics, LLC",
                "attribution": "Copyright (c) I2RT Robotics",
                "source": "https://github.com/i2rt-robotics/i2rt/blob/d36027fc50e12d9261f091f9d91c4715bb5e398f/i2rt/robots/get_robot.py#L129",
                "license": "MIT",
                "uri": "https://github.com/i2rt-robotics/i2rt/blob/main/LICENSE",
                "downloaded": "2026",
            },
        ),
    )

    for keyword, fields in robot_licenses:
        if keyword in identifier:
            return {
                "data_type": "robots",
                "data_source": data_source,
                "asset_id": identifier,
                **fields,
            }

    raise NotImplementedError(f"Got unknown robot {identifier}")

resolve_scene_license

resolve_scene_license(data_source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def resolve_scene_license(data_source, identifier):
    """Build the license record for a scene.

    ``identifier`` is either a scene index or an archive-name fragment:

    * An int, or a string containing digits, is treated as an index (for a
      string, its first run of digits). The scene is installed via
      ``install_scene_from_source_index`` and its path resolved with
      ``scene_path_resolve``. The referenced object assets are listed under
      ``"assets"`` (via ``scene_includes``) when there are any.
    * Anything else must be a substring of exactly one scene archive name;
      that archive is installed and no asset list is produced.

    Raises:
        ValueError: If a string identifier does not occur in the resolved scene
            path, or if no archive name contains a non-numeric identifier.
        AssertionError: If several archive names contain the identifier.
    """
    given_identifier = identifier

    # Reduce strings to their first run of digits, if they have one.
    if isinstance(identifier, str):
        digits = re.search(r"\d+", identifier)
        if digits:
            identifier = int(digits.group())

    try:
        identifier = int(identifier)
    except ValueError:
        indexed = False
    else:
        indexed = True

    includes = []
    if indexed:
        source_to_archives = install_scene_from_source_index(data_source, identifier)
        scene_path = scene_path_resolve(data_source, identifier, source_to_archives)
        # Guard against the digit extraction resolving to an unrelated scene:
        # the caller's original string must appear in the resolved path.
        if isinstance(given_identifier, str) and given_identifier not in str(scene_path):
            raise ValueError(f"Non-valid identifier {given_identifier}")
        includes = scene_includes(scene_path)
    else:
        archive_names = get_resource_manager().tries("scenes", data_source).keys()
        matches = [name for name in archive_names if identifier in name]

        if not matches:
            raise ValueError(f"No archives for `scenes` {data_source} {identifier}")

        assert len(matches) == 1, (
            f"Error: multiple archives for `scenes` {data_source} {identifier}"
        )

        get_resource_manager().install_packages_bulk("scenes", {data_source: matches})

    scene_license = {
        "data_type": "scenes",
        "data_source": data_source,
        "asset_id": str(identifier),
        **DEFAULT_LICENSE,
        "attribution": ATTRIBUTION_TEMPLATE.format(assets="Scene"),
        "scope": "Scene composition, layout, non-object-specific textures, and metadata.",
        "relationship_to_assets": "collection",
        "asset_licenses": "Assets are independently licensed; see assets info below for details.",
        "license_determination": "Scenes are collections referencing independently licensed assets;"
        f" {DEFAULT_LICENSE['license']} applies only to scene composition, layout, and metadata.",
    }
    if includes:
        scene_license["assets"] = includes

    return scene_license

scene_includes

scene_includes(scene_path)
Source code in molmo_spaces/utils/license_utils.py
def scene_includes(scene_path):
    """List the object assets referenced by a scene, with their attributions.

    Every object file the scene references (from ``find_object_paths``, THOR
    objects included) is mapped to the single object archive that contains
    it. The archive identifier is then attributed via
    ``resolve_object_license``.

    Args:
        scene_path: Path of the scene file.

    Returns:
        dict[str, list[dict[str, Any]]]: One entry per object source (sorted by
        name). Each maps to a list of ``{"identifier", "attribution"}`` dicts
        sorted by identifier, so the output is reproducible across runs.

    Raises:
        AssertionError: If an asset does not map to exactly one archive.
    """

    def archive_identifier(source, rel_asset):
        # Map one referenced asset file to the identifier of its object archive.
        archives = get_resource_manager().find_archives("objects", source, [rel_asset])

        assert len(archives) == 1, (
            f"Expected exactly one archive for {rel_asset}, got {len(archives)}"
        )

        return validate_identifier("objects", source, archives[0])

    includes = defaultdict(set)
    for source, rel_asset in find_object_paths(scene_path, exclude_thor=False):
        includes[source].add(archive_identifier(source, rel_asset))

    # Iterating a set of strings depends on hash randomization, so identifiers
    # are sorted to make the generated license output deterministic.
    ret: dict[str, list[dict[str, Any]]] = {
        source: [
            {
                "identifier": identifier,
                "attribution": resolve_object_license(source, identifier)["attribution"],
            }
            for identifier in sorted(includes[source])
        ]
        for source in sorted(includes)
    }

    return ret

scene_path_resolve

scene_path_resolve(source: str, idx: int, source_to_archives: dict[str, Collection[str]]) -> Path
Source code in molmo_spaces/utils/license_utils.py
def scene_path_resolve(
    source: str, idx: int, source_to_archives: dict[str, Collection[str]]
) -> Path:
    """Resolve the file path of scene ``idx`` from ``source``.

    Sources whose name contains ``"procthor"`` or ``"holodeck"`` use
    ``procthor_resolver``; those containing ``"ithor"`` use ``ithor_resolver``.
    The first archive listed for ``source`` in ``source_to_archives`` selects
    the relative paths (modalities) handed to the resolver.

    Raises:
        NotImplementedError: If ``source`` matches no known scene family.
    """
    if any(family in source for family in ("procthor", "holodeck")):
        resolver = procthor_resolver
    elif "ithor" in source:
        resolver = ithor_resolver
    else:
        raise NotImplementedError(f"Missing implementation for {source}")

    first_archive = [*source_to_archives[source]][0]
    scene_info = get_resource_manager().source_info("scenes", source, recursive=False)
    available_paths = scene_info["archive_to_relative_paths"][first_archive]
    return resolver(source, idx, scene_info, available_paths)

validate_identifier

validate_identifier(data_type, source, identifier)
Source code in molmo_spaces/utils/license_utils.py
def validate_identifier(data_type, source, identifier):
    """Normalize ``identifier`` to the name of the archive that contains it.

    The first archive name (in ``tries(data_type, source)`` key order) that
    contains ``identifier`` as a substring is chosen. The name is cut at its
    first ``"."``, and every occurrence of ``"{source}_"`` is removed (not
    only a leading prefix).

    NOTE(review): substring matching means e.g. ``"Apple_1"`` would also match
    an archive for ``"Apple_10"``; whichever comes first wins.

    Raises:
        ValueError: If no archive name contains ``identifier``.
    """
    archive_names = get_resource_manager().tries(data_type, source).keys()
    matched = next((name for name in archive_names if identifier in name), None)
    if matched is None:
        raise ValueError(f"{identifier=} is not in {source=} ({data_type=})")

    stem = matched.split(".")[0]
    return stem.replace(f"{source}_", "")

validate_objaverse_identifier

validate_objaverse_identifier(identifier)
Source code in molmo_spaces/utils/license_utils.py
def validate_objaverse_identifier(identifier):
    """Return ``identifier`` normalized as an Objaverse object identifier.

    If ``ObjectMeta`` has an annotation for it, the annotation's
    ``isObjaverse`` flag decides and the identifier is returned unchanged.
    Otherwise the identifier is resolved against the ``objects``/``objaverse``
    archives with ``validate_identifier``.

    Raises:
        ValueError: If the annotation marks the object as not from Objaverse,
            or (without annotation) no archive matches.
    """
    anno = ObjectMeta.annotation(identifier)
    if anno is None:
        return validate_identifier("objects", "objaverse", identifier)

    if not anno["isObjaverse"]:
        raise ValueError(f"{identifier=} is not in `objaverse`")
    return identifier

validate_thor_identifier

validate_thor_identifier(identifier)
Source code in molmo_spaces/utils/license_utils.py
def validate_thor_identifier(identifier):
    """Return ``identifier`` normalized as a THOR object identifier.

    If ``ObjectMeta`` has an annotation for it, the object must not be flagged
    ``isObjaverse`` and the identifier is returned unchanged. Otherwise the
    identifier is resolved against the ``objects``/``thor`` archives with
    ``validate_identifier``.

    Raises:
        ValueError: If the annotation marks the object as an Objaverse asset,
            or (without annotation) no archive matches.
    """
    anno = ObjectMeta.annotation(identifier)
    if anno is None:
        return validate_identifier("objects", "thor", identifier)

    if anno["isObjaverse"]:
        raise ValueError(f"{identifier=} is not in `thor`")
    return identifier

linalg_utils

Functions:

Name Description
euler_yaw_to_quat

Convert euler (0, 0, yaw) to quat (w, x, y, z)

global_to_relative_transform
homogenize

Project a vector to homogeneous coordinates. Accepts either a single vector or a batch.

interp

Linear interpolation of vector-valued functions of scalars. Similar to np.interp but for multi-dimensional arrays.

inverse_homogeneous_matrix

Compute the inverse of a 4x4 homogeneous transformation matrix.

normalize_ang_error
obb_2d

Compute the oriented bounding box (OBB) of a set of 2D points.

quat_to_euler_yaw

Convert quaternion (w, x, y, z) to euler yaw (radians)

relative_to_global_transform
single_or_batch

Decorator to allow a function to accept a single input or a batch of inputs.

skew

Compute the skew-symmetric matrix of a 3D vector.

swing_twist

Decomposes quat into a rotation around axis and a rotation around an axis perpendicular to axis.

transform_to_twist

Given a 4x4 transformation matrix, return the twist as (lin_vel, ang_vel).

twist_to_transform

Given a linear velocity and angular velocity, return the 4x4 transformation matrix.

euler_yaw_to_quat

euler_yaw_to_quat(yaw)

Convert euler (0, 0, yaw) to quat (w, x, y, z)

Source code in molmo_spaces/utils/linalg_utils.py
def euler_yaw_to_quat(yaw):
    """
    Convert euler (0, 0, yaw) to quat (w, x, y, z)
    """
    return R.from_euler("xyz", [0, 0, yaw], degrees=False).as_quat(scalar_first=True)

global_to_relative_transform

global_to_relative_transform(x, base)
Source code in molmo_spaces/utils/linalg_utils.py
def global_to_relative_transform(x, base):
    """Express world-frame transform(s) ``x`` in the frame of ``base``.

    Computes ``inverse_homogeneous_matrix(base) @ x``; ``base`` must be a 4x4
    homogeneous transform (otherwise the inverse helper raises ValueError).
    """
    base_inverse = inverse_homogeneous_matrix(base)
    return base_inverse @ x

homogenize

homogenize(x: ndarray)

Project a vector to homogeneous coordinates. Accepts either a single vector or a batch.

Source code in molmo_spaces/utils/linalg_utils.py
@single_or_batch
def homogenize(x: np.ndarray):
    """
    Append a trailing 1 to each vector, giving homogeneous coordinates.

    Through ``single_or_batch`` this accepts one vector of shape (d,) (result
    (d + 1,)) or a batch of shape (n, d) (result (n, d + 1)).
    """
    assert x.ndim == 2
    ones_column = np.ones((x.shape[0], 1))
    return np.concatenate((x, ones_column), axis=1)

interp

interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | None = None, right: ArrayLike | None = None)

Linear interpolation of vector-valued functions of scalars. Similar to np.interp but for multi-dimensional arrays.

Source code in molmo_spaces/utils/linalg_utils.py
def interp(
    x: ArrayLike,
    xp: ArrayLike,
    fp: ArrayLike,
    left: ArrayLike | None = None,
    right: ArrayLike | None = None,
):
    """
    Linear interpolation of vector-valued functions of scalars. Similar to np.interp but for multi-dimensional arrays.
    """
    x = np.asarray(x)
    is_batch = x.ndim > 0
    x = x.reshape(-1)
    xp = np.asarray(xp)
    fp = np.asarray(fp)
    if len(fp.shape) == 1:
        fp = fp.reshape(-1, 1)
    assert len(xp.shape) == 1 and xp.shape[0] == fp.shape[0]

    # Handle out of bounds
    ret = np.zeros((x.shape[0], fp.shape[-1]), fp.dtype)
    lt_mask = x <= xp[0]
    gt_mask = x > xp[-1]
    if np.any(lt_mask):
        ret[lt_mask] = left if left is not None else fp[0]
    if np.any(gt_mask):
        ret[gt_mask] = right if right is not None else fp[-1]

    in_bounds_mask = ~lt_mask & ~gt_mask
    x_in_bounds = x[in_bounds_mask]
    i = np.searchsorted(xp, x_in_bounds)

    x0, x1 = xp[i - 1], xp[i]
    f0, f1 = fp[i - 1], fp[i]
    ret[in_bounds_mask] = f0 + (f1 - f0) / (x1 - x0)[:, None] * (x_in_bounds - x0)[:, None]
    return ret if is_batch else ret[0]

inverse_homogeneous_matrix

inverse_homogeneous_matrix(matrix: ndarray)

Compute the inverse of a 4x4 homogeneous transformation matrix.

Args: matrix (numpy.ndarray): A 4x4 homogeneous transformation matrix.

Returns: numpy.ndarray: The inverse of the input matrix.

Source code in molmo_spaces/utils/linalg_utils.py
def inverse_homogeneous_matrix(matrix: np.ndarray):
    """
    Invert a 4x4 rigid homogeneous transform [R | t; 0 | 1].

    Uses the closed form [R^T | -R^T t; 0 | 1] rather than a general matrix
    inverse, so the upper-left 3x3 block is assumed to be a rotation.

    Args:
        matrix (numpy.ndarray): A 4x4 homogeneous transformation matrix.

    Returns:
        numpy.ndarray: The 4x4 inverse transform (built on a float64 identity).

    Raises:
        ValueError: If ``matrix`` is not 4x4.
    """
    if matrix.shape != (4, 4):
        raise ValueError("Input matrix must be a 4x4 matrix.")

    rot_t = matrix[:3, :3].T
    result = np.eye(4)
    result[:3, :3] = rot_t
    result[:3, 3] = -(rot_t @ matrix[:3, 3])
    return result

normalize_ang_error

normalize_ang_error(ang)
Source code in molmo_spaces/utils/linalg_utils.py
def normalize_ang_error(ang):
    """Wrap angle(s) in radians into the half-open interval [-pi, pi).

    An input of exactly ``pi`` maps to ``-pi``. Works on scalars and arrays.
    """
    shifted = ang + np.pi
    wrapped = shifted % (2 * np.pi)
    return wrapped - np.pi

obb_2d

obb_2d(points: ndarray) -> tuple[ndarray, ndarray, ndarray]

Compute the oriented bounding box (OBB) of a set of 2D points. Parameters: points (np.ndarray): A 2D numpy array of shape (N, 2) representing the coordinates of the points.

Returns: tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: - pos (np.ndarray): The center position of the OBB. - minor_axis (np.ndarray): The half-extent vector along the shorter side of the OBB. - major_axis (np.ndarray): The half-extent vector along the longer side of the OBB.

Source code in molmo_spaces/utils/linalg_utils.py
def obb_2d(points: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the minimum-area oriented bounding box (OBB) of a set of 2D points.

    The minimum-area box has a side collinear with an edge of the convex hull,
    so each hull edge direction is tried and the smallest box is kept.

    Parameters:
    points (np.ndarray): A 2D numpy array of shape (N, 2) representing the coordinates of the points.

    Returns:
    tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
        - pos (np.ndarray): The center position of the OBB.
        - minor_axis (np.ndarray): Half-extent vector along the shorter side of the OBB.
        - major_axis (np.ndarray): Half-extent vector along the longer side of the OBB.
    """
    points = np.asarray(points)
    hull = ConvexHull(points)
    hull_points = points[hull.vertices]

    edges = np.diff(np.concatenate([hull_points, hull_points[:1]], axis=0), axis=0)
    x_axes = edges / np.linalg.norm(edges, axis=1)[:, None]
    y_axes = np.column_stack([-x_axes[:, 1], x_axes[:, 0]])
    # rotmats[k] has columns (x_axis, y_axis) of the k-th candidate box frame.
    rotmats = np.stack([x_axes, y_axes], axis=2)

    # Box-frame coordinates are projections onto the frame axes, R^T p, which
    # for row vectors is p @ R. (p @ R^T would rotate the points the opposite
    # way, so candidate boxes would not be aligned with the hull edges.)
    rot_points = np.expand_dims(points, axis=0) @ rotmats
    rot_merged_mins = np.min(rot_points, axis=1)
    rot_merged_maxs = np.max(rot_points, axis=1)
    areas = np.prod(rot_merged_maxs - rot_merged_mins, axis=1)
    best_bbox_idx = np.argmin(areas)

    rotmat = rotmats[best_bbox_idx]
    rot_merged_min = rot_merged_mins[best_bbox_idx]
    rot_merged_max = rot_merged_maxs[best_bbox_idx]
    # Back to world coordinates: p = R (R^T p).
    pos = rotmat @ (rot_merged_min + rot_merged_max) / 2
    half_size = (rot_merged_max - rot_merged_min) / 2
    # Rows: each frame axis (a column of rotmat) scaled by its half-extent.
    axes = (rotmat * half_size).T
    minor_axis, major_axis = sorted(axes, key=np.linalg.norm)
    return pos, minor_axis, major_axis

quat_to_euler_yaw

quat_to_euler_yaw(quat)

Convert quaternion (w, x, y, z) to euler yaw (radians)

Source code in molmo_spaces/utils/linalg_utils.py
def quat_to_euler_yaw(quat):
    """
    Convert quaternion (w, x, y, z) to euler yaw (radians)
    """
    return R.from_quat(quat, scalar_first=True).as_euler("xyz", degrees=False)[2]

relative_to_global_transform

relative_to_global_transform(x, base)
Source code in molmo_spaces/utils/linalg_utils.py
def relative_to_global_transform(x, base):
    """Map transform(s) ``x`` expressed in the ``base`` frame to the world frame.

    Computes ``base @ x``; for a rigid ``base`` this undoes
    ``global_to_relative_transform``.
    """
    world = base @ x
    return world

single_or_batch

single_or_batch(func)

Decorator to allow a function to accept a single input or a batch of inputs. The decorated function should always accept and return batches.

Source code in molmo_spaces/utils/linalg_utils.py
def single_or_batch(func):
    """
    Let a batch-only function also accept a single vector.

    ``func`` must accept and return batches. The array argument is the first
    positional argument, or the second when the first has a ``__dict__``
    (taken to be ``self``). A 1-D array is reshaped to (1, d) before the call
    and the first row of the result is returned.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Skip a leading `self` so the decorator also works on methods.
        arg_pos = 0
        if args and hasattr(args[0], "__dict__"):
            arg_pos = 1
        batch = np.asarray(args[arg_pos])
        single = batch.ndim == 1
        if single:
            batch = batch.reshape(1, -1)
        call_args = (*args[:arg_pos], batch, *args[arg_pos + 1 :])
        result = func(*call_args, **kwargs)
        if single:
            return result[0]
        return result

    return wrapper

skew

skew(v: ndarray)

Compute the skew-symmetric matrix of a 3D vector.

Source code in molmo_spaces/utils/linalg_utils.py
def skew(v: np.ndarray):
    """
    Return the 3x3 cross-product matrix of ``v``, so that skew(v) @ w == cross(v, w).
    """
    vx, vy, vz = v[0], v[1], v[2]
    return np.array(
        [
            [0, -vz, vy],
            [vz, 0, -vx],
            [-vy, vx, 0],
        ]
    )

swing_twist

swing_twist(quat: ndarray, axis: ndarray)

Decomposes quat into a rotation around axis and a rotation around an axis perpendicular to axis.

Note: Assumes quaternions are [w,x,y,z]

Returns quaternions (swing, twist) where quat = swing * twist, and twist is a rotation around axis

Source code in molmo_spaces/utils/linalg_utils.py
def swing_twist(quat: np.ndarray, axis: np.ndarray):
    """
    Decomposes quat into a rotation around axis and a rotation around an
    axis perpendicular to axis.

    Note: Assumes quaternions are [w,x,y,z]

    Returns quaternions (swing, twist) where quat = swing * twist, and
    twist is a rotation around axis
    """
    axis = axis.astype(np.float64) / np.linalg.norm(axis)
    rot_ax = quat[1:]
    p = np.dot(rot_ax, axis) * axis
    twist = np.hstack([quat[:1], p])
    twist /= np.linalg.norm(twist)
    quat_rot = R.from_quat(quat, scalar_first=True)
    twist_rot = R.from_quat(twist, scalar_first=True)
    swing = (quat_rot * twist_rot.inv()).as_quat(scalar_first=True)
    return swing, twist

transform_to_twist

transform_to_twist(T: ndarray)

Given a 4x4 transformation matrix, return the twist as (lin_vel, ang_vel). Mathematically, this is computing the logarithmic map of SE(3). Equivalent to pin.log6.

See: https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf (Sec 9.4.2)

Source code in molmo_spaces/utils/linalg_utils.py
def transform_to_twist(T: np.ndarray):
    """
    Logarithmic map of SE(3): convert a 4x4 rigid transform into a twist.

    Returns ``(lin_vel, ang_vel)``. ``ang_vel`` is the rotation vector w of the
    rotation block, and ``lin_vel = V^-1 t`` with
    ``V = I + (1 - cos theta)/theta^2 [w]x + (theta - sin theta)/theta^3 [w]x^2``
    and ``theta = |w|``. For theta < 1e-6 the translation is returned as-is.
    Equivalent to pin.log6.

    See: https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf (Sec 9.4.2)
    """
    rotvec = R.from_matrix(T[:3, :3]).as_rotvec()
    angle = np.linalg.norm(rotvec)
    translation = T[:3, 3]
    if np.abs(angle) < 1e-6:
        return translation, rotvec

    # Cross-product (skew-symmetric) matrix of the rotation vector.
    wx = np.array(
        [
            [0, -rotvec[2], rotvec[1]],
            [rotvec[2], 0, -rotvec[0]],
            [-rotvec[1], rotvec[0], 0],
        ]
    )
    V = (
        np.eye(3)
        + (1 - np.cos(angle)) / angle**2 * wx
        + (angle - np.sin(angle)) / angle**3 * np.dot(wx, wx)
    )
    return np.linalg.solve(V, translation), rotvec

twist_to_transform

twist_to_transform(lin_vel: ndarray, ang_vel: ndarray)

Given a linear velocity and angular velocity, return the 4x4 transformation matrix. Mathematically, this is computing the exponential map of SE(3). Equivalent to pin.exp6.

See: https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf (Sec 9.4.2)

Source code in molmo_spaces/utils/linalg_utils.py
def twist_to_transform(lin_vel: np.ndarray, ang_vel: np.ndarray):
    """
    Exponential map of SE(3): convert a twist into a 4x4 rigid transform.

    The rotation block is the rotation with rotation vector ``ang_vel``. The
    translation is ``V @ lin_vel`` with
    ``V = I + (1 - cos theta)/theta^2 [w]x + (theta - sin theta)/theta^3 [w]x^2``,
    ``theta = |ang_vel|``, and ``V = I`` when theta < 1e-6.
    Equivalent to pin.exp6.

    See: https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf (Sec 9.4.2)
    """
    angle = np.linalg.norm(ang_vel)
    out = np.eye(4)
    out[:3, :3] = R.from_rotvec(ang_vel).as_matrix()

    small_angle = np.abs(angle) < 1e-6
    if small_angle:
        V = np.eye(3)
    else:
        # Cross-product (skew-symmetric) matrix of the angular velocity.
        wx = np.array(
            [
                [0, -ang_vel[2], ang_vel[1]],
                [ang_vel[2], 0, -ang_vel[0]],
                [-ang_vel[1], ang_vel[0], 0],
            ]
        )
        V = (
            np.eye(3)
            + (1 - np.cos(angle)) / angle**2 * wx
            + (angle - np.sin(angle)) / angle**3 * np.dot(wx, wx)
        )

    out[:3, 3] = V @ lin_vel
    return out

mj_model_and_data_utils

Functions:

Name Description
body_aabb

Computes the axis-aligned bounding box (AABB) for a body in a MuJoCo model.

body_base_pos

Returns the base position of a body in the world frame.

body_pose
descendant_bodies

Get all bodies descended from a body in a MuJoCo model.

descendant_geoms

Get all geoms attached to descendants of a body in a MuJoCo model.

extract_mj_names

Collect object names and name/id mappings for one MuJoCo object type. Adapted from https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1127

geom_aabb

Computes the axis-aligned bounding box (AABB) for a list of geometries in a MuJoCo model.

mesh_aabb

Compute the tight AABB in world space for a mesh geom using its vertices.

site_pose

body_aabb

body_aabb(model: MjModel, data: MjData, body_id: int, visible_only: bool = True) -> tuple[ndarray, ndarray]

Computes the axis-aligned bounding box (AABB) for a body in a MuJoCo model.

Parameters:

Name Type Description Default
model MjModel

The MuJoCo model containing the body.

required
data MjData

The MuJoCo data containing the state of the model.

required
body_id int

The id of the body to compute the AABB for.

required
visible_only bool

Whether to only include visible geoms (groups 0-2). This can help make the AABB fit tighter.

True

Returns:

Name Type Description
tuple tuple[ndarray, ndarray]

A tuple containing: - numpy.ndarray: The center of the AABB in world space. - numpy.ndarray: The x,y,z dimensions of the AABB.

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def body_aabb(
    model: mujoco.MjModel, data: mujoco.MjData, body_id: int, visible_only: bool = True
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the world-space axis-aligned bounding box (AABB) of a body subtree.

    The box covers all geoms attached to the body or any of its descendants
    (as listed by ``descendant_geoms``) and is computed by ``geom_aabb``.

    Args:
        model (mujoco.MjModel): Model that contains the body.
        data (mujoco.MjData): State of the model.
        body_id (int): Id of the root body.
        visible_only (bool): If True, only geoms in groups 0-2 are used, which
            usually gives a tighter box.

    Returns:
        tuple: ``(center, size)``, the AABB center in world space and its
        x, y, z extents. A body without geoms yields a copy of its position
        and a zero size.
    """
    geom_ids = descendant_geoms(model, body_id, visible_only=visible_only)
    if geom_ids:
        return geom_aabb(model, data, geom_ids)
    # No geoms under this body: fall back to the body origin with zero extent.
    return data.xpos[body_id].copy(), np.zeros(3)

body_base_pos

body_base_pos(data: MjData, body_id: int, visible_only: bool = True) -> ndarray

Returns the base position of a body in the world frame. In XY, this is the center of the AABB, and in Z, this is the bottom of the AABB.

Parameters:

Name Type Description Default
data MjData

MjData object

required
body_id int

ID of the body to get the base position of.

required
visible_only bool

Whether to only include visible geoms (groups 0-2). This can help make the AABB fit tighter.

True

Returns:

Type Description
ndarray

np.ndarray: The base position of the body in the world frame, of shape (3,).

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def body_base_pos(data: mujoco.MjData, body_id: int, visible_only: bool = True) -> np.ndarray:
    """
    Return the world-frame "base" point of a body: the bottom-center of its AABB.

    X and Y are the AABB center; Z is the AABB's lowest point.

    Args:
        data: MjData object; its ``model`` attribute supplies the model.
        body_id: ID of the body.
        visible_only (bool): If True, only geoms in groups 0-2 contribute to
            the AABB, which usually makes it tighter.

    Returns:
        np.ndarray: Base position of shape (3,).
    """
    center, size = body_aabb(data.model, data, body_id, visible_only=visible_only)
    half_height_offset = np.array([0, 0, size[2] / 2])
    return center - half_height_offset

body_pose

body_pose(data: MjData, body_id: int) -> ndarray
Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def body_pose(data: mujoco.MjData, body_id: int) -> np.ndarray:
    """Return the pose of body ``body_id`` as a 4x4 homogeneous matrix.

    The rotation block is ``data.xmat[body_id]`` reshaped (row-major) to 3x3
    and the translation is ``data.xpos[body_id]``.
    """
    pose = np.eye(4)
    pose[:3, :3] = data.xmat[body_id].reshape(3, 3)
    pose[:3, 3] = data.xpos[body_id]
    return pose

descendant_bodies

descendant_bodies(model: MjModel, body_id: int)

Get all bodies descended from a body in a MuJoCo model.

Parameters:

Name Type Description Default
model MjModel

The MuJoCo model to use.

required
body_id int

The id of the body to get the descendants of.

required

Returns:

Type Description

set[int]: A set of the ids of the bodies descended from the body, including the body itself.

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def descendant_bodies(model: MjModel, body_id: int):
    """
    Collect the ids of a body and every body below it in the kinematic tree.

    Children are found through ``model.body_parentid``. Body 0 short-circuits
    to all ``model.nbody`` bodies. Presumably this is because the world body is
    its own parent in MuJoCo, which would make the walk below loop forever.

    Args:
        model (MjModel): The MuJoCo model.
        body_id (int): Id of the subtree root.

    Returns:
        set[int]: Ids of the body and all of its descendants.
    """
    if body_id == 0:
        return set(range(model.nbody))

    found = {body_id}
    frontier = [body_id]
    while frontier:
        parent = frontier.pop()
        children = np.where(model.body_parentid == parent)[0]
        found.update(children)
        frontier.extend(children)
    return found

descendant_geoms

descendant_geoms(model: MjModel, body_id: int, visible_only: bool = True) -> list[int]

Get all geoms attached to descendants of a body in a MuJoCo model.

Parameters:

Name Type Description Default
model MjModel

The MuJoCo model to use.

required
body_id int

The id of the body to get the geoms of.

required
visible_only bool

Whether to only include visible geoms (groups 0-2).

True

Returns:

Type Description
list[int]

list[int]: A sorted list of the ids of the geoms attached to descendants of the body, or the body itself.

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def descendant_geoms(model: MjModel, body_id: int, visible_only: bool = True) -> list[int]:
    """
    List the geoms attached to a body or to any body in its subtree.

    Args:
        model (MjModel): The MuJoCo model to query.
        body_id (int): Id of the body whose subtree's geoms are collected.
        visible_only (bool): If True, keep only geoms in groups 0-2 (treated as visible).

    Returns:
        list[int]: Ascending geom ids whose owning body is ``body_id`` or one of its descendants.
    """
    subtree = np.array(list(descendant_bodies(model, body_id)))
    # Mark every geom whose owning body lies in the subtree.
    in_subtree = np.isin(model.geom_bodyid, subtree)
    geom_ids = np.where(in_subtree)[0]
    if visible_only:
        geom_ids = geom_ids[model.geom_group[geom_ids] < 3]
    return geom_ids.tolist()

extract_mj_names

extract_mj_names(model, name_adr: ndarray | None, num_obj: int, obj_type: mjtObj)

See https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1127

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def extract_mj_names(model, name_adr: np.ndarray | None, num_obj: int, obj_type: mjtObj):
    """
    Build name lookup tables for all objects of one MuJoCo object type.

    Follows the mujoco-py helper at
    https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1127

    Args:
        model: The MuJoCo model whose names are read (via ``mujoco.mj_id2name``).
        name_adr: Unused; kept for signature compatibility with the mujoco-py helper.
        num_obj: Number of objects of ``obj_type`` in the model.
        obj_type: The ``mjtObj`` type of the objects.

    Returns:
        tuple: ``(names, name2id, id2name)`` where
            - ``names`` is a tuple of the names of the *named* objects, ordered by increasing id;
            - ``name2id`` maps each name to its id (unnamed objects are not included);
            - ``id2name`` maps every id in ``range(num_obj)`` to its name, or None if unnamed.
    """
    # objects don't need to be named in the XML, so name might be None
    id2name = {i: None for i in range(num_obj)}
    name2id = {}
    for i in range(num_obj):
        name = mujoco.mj_id2name(model, obj_type, i)
        if not name:
            # Unnamed objects stay None in id2name only. Registering them in name2id would
            # collapse all of them onto a single None key and leak a None into `names`.
            continue
        name2id[name] = i
        id2name[i] = name

    # sort names by increasing id to keep order deterministic
    return tuple(id2name[nid] for nid in sorted(name2id.values())), name2id, id2name

geom_aabb

geom_aabb(model: MjModel, data: MjData, geom_ids: list[int], tight_mesh: bool = True) -> tuple[ndarray, ndarray]

Computes the axis-aligned bounding box (AABB) for a list of geometries in a MuJoCo model.

Parameters:

Name Type Description Default
model MjModel

The MuJoCo model containing the geometries.

required
data MjData

The MuJoCo data containing the state of the model.

required
geom_ids list[int]

A list of geometry IDs for which to compute the AABB.

required
tight_mesh bool

Whether to compute the tight AABB for mesh geoms. If False, the AABB will be computed using the geom_aabb field, and may not be tight in world space.

True

Returns:

Name Type Description
tuple tuple[ndarray, ndarray]

A tuple (center, size): center (numpy.ndarray) is the center of the merged AABB in world space, and size (numpy.ndarray) holds the x, y, z dimensions of the merged AABB.

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def geom_aabb(
    model: mujoco.MjModel, data: mujoco.MjData, geom_ids: list[int], tight_mesh: bool = True
) -> tuple[np.ndarray, np.ndarray]:
    """
    Computes the axis-aligned bounding box (AABB) for a list of geometries in a MuJoCo model.

    Each geom contributes the 8 corners of a box in world space:
      - mesh geoms (when ``tight_mesh``) use the tight vertex-based box from ``mesh_aabb``;
      - all other geoms use their local ``model.geom_aabb`` box (offset + half-sizes),
        rotated and translated by the geom's current world pose (``data.geom_xmat``/``geom_xpos``).
    The result is the world-axis-aligned box enclosing all of these corners.

    Args:
        model (mujoco.MjModel): The MuJoCo model containing the geometries.
        data (mujoco.MjData): The MuJoCo data containing the state of the model.
        geom_ids (list[int]): Geometry IDs for which to compute the AABB. Any sized sequence
            of ids is accepted, including a 1-D numpy array.
        tight_mesh (bool): Whether to compute the tight AABB for mesh geoms.
            If False, the AABB will be computed using the geom_aabb field, and may not be tight in world space.

    Returns:
        tuple: A tuple containing:
            - numpy.ndarray: The center of the merged AABB in world space.
            - numpy.ndarray: The x,y,z dimensions of the merged AABB.
            Both are zero vectors when ``geom_ids`` is empty.
    """
    # Use len() rather than truthiness: `not geom_ids` raises ValueError for a numpy
    # array holding more than one id.
    if len(geom_ids) == 0:
        # If no geoms provided, return zero-sized AABB at origin
        return np.zeros(3), np.zeros(3)

    vertices = []
    # The 8 sign combinations of (+-1, +-1, +-1) select the corners of a box.
    corners = np.array(list(itertools.product([-1.0, 1.0], repeat=3)))
    for geom_id in geom_ids:
        if tight_mesh and model.geom_type[geom_id] == mujoco.mjtGeom.mjGEOM_MESH.value:
            mesh_aabb_center, mesh_aabb_size = mesh_aabb(model, data, geom_id)
            vertices.append(mesh_aabb_center + corners * mesh_aabb_size / 2)
        else:
            geom_rotmat = data.geom_xmat[geom_id].reshape(3, 3)
            geom_pos = data.geom_xpos[geom_id]

            # geom_aabb row: first 3 entries are the box offset, last 3 its half-sizes (geom frame).
            aabb = model.geom_aabb[geom_id]
            local_corners = aabb[:3] + corners * aabb[3:]
            world_corners = local_corners @ geom_rotmat.T + geom_pos
            vertices.append(world_corners)

    # merge aabbs
    vertices = np.concatenate(vertices, axis=0)
    merged_min = np.min(vertices, axis=0)
    merged_max = np.max(vertices, axis=0)
    return (merged_min + merged_max) / 2, merged_max - merged_min

mesh_aabb

mesh_aabb(model: MjModel, data: MjData, geom_id: int) -> tuple[ndarray, ndarray]

Compute the tight AABB in world space for a mesh geom using its vertices.

Parameters:

Name Type Description Default
model MjModel

The MuJoCo model containing the geom.

required
data MjData

The MuJoCo data containing the state of the model.

required
geom_id int

The id of the mesh geom to compute the AABB for. Must be a mesh geom.

required

Returns:

Name Type Description
tuple tuple[ndarray, ndarray]

A tuple (center, size): center (numpy.ndarray) is the center of the AABB in world space, and size (numpy.ndarray) holds the x, y, z dimensions of the AABB.

Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def mesh_aabb(
    model: mujoco.MjModel, data: mujoco.MjData, geom_id: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Tight world-space AABB of a single mesh geom, computed from the mesh's vertices.

    Args:
        model (mujoco.MjModel): Model that owns the geom and its mesh vertex data.
        data (mujoco.MjData): Simulation state providing the current body poses.
        geom_id (int): Id of the geom; it must be of type ``mjGEOM_MESH`` (asserted).

    Returns:
        tuple: A tuple containing:
            - numpy.ndarray: The center of the AABB in world space.
            - numpy.ndarray: The x,y,z dimensions of the AABB.
    """
    assert model.geom_type[geom_id] == mujoco.mjtGeom.mjGEOM_MESH.value
    mesh_id = model.geom_dataid[geom_id]
    first_vert = model.mesh_vertadr[mesh_id]
    last_vert = first_vert + model.mesh_vertnum[mesh_id]

    # World pose of the geom = pose of its parent body composed with the geom's offset in that body.
    body_from_geom = pos_quat_to_pose_mat(model.geom_pos[geom_id], model.geom_quat[geom_id])
    world_from_geom = body_pose(data, model.geom_bodyid[geom_id]) @ body_from_geom

    # NOTE(review): assumes mesh_vert is expressed in the geom's local frame — confirm.
    local_verts = model.mesh_vert[first_vert:last_vert]
    world_verts = local_verts @ world_from_geom[:3, :3].T + world_from_geom[:3, 3]

    lower = world_verts.min(axis=0)
    upper = world_verts.max(axis=0)
    return (lower + upper) / 2, upper - lower

site_pose

site_pose(data: MjData, site_id: int) -> ndarray
Source code in molmo_spaces/utils/mj_model_and_data_utils.py
def site_pose(data: mujoco.MjData, site_id: int) -> np.ndarray:
    """Return the current world pose of a site as a 4x4 homogeneous transform.

    The rotation block comes from ``data.site_xmat[site_id]`` (reshaped to 3x3) and the
    translation column from ``data.site_xpos[site_id]``.
    """
    rotation = data.site_xmat[site_id].reshape(3, 3)
    translation = data.site_xpos[site_id]
    pose = np.eye(4)
    pose[:3, :3] = rotation
    pose[:3, 3] = translation
    return pose

mp_logging

Classes:

Name Description
ColoredFormatter

Format a log string with colors.

ImportChecker

Functions:

Name Description
find_free_port

Finds a free port for distributed training.

get_logger

Get a logging.Logger to stderr. It can be called whenever we wish to

get_worker_logger

Create a logger specific to a worker that includes the worker ID in all messages

init_logging

Init the logging.Logger.

restore_worker_stdout

Restore the previous stdout for the current thread

setup_worker_stdout

Set up stdout redirection for a worker thread to use the worker's logger

update_log_level
worker_stdout_context

Context manager for worker-specific stdout redirection

Attributes:

Name Type Description
HUMAN_LOG_LEVELS tuple[str, ...]

Available log levels: "debug", "info", "warning", "error", "none"

HUMAN_LOG_LEVELS module-attribute

HUMAN_LOG_LEVELS: tuple[str, ...] = ('debug', 'info', 'warning', 'error', 'none')

Available log levels: "debug", "info", "warning", "error", "none"

ColoredFormatter

ColoredFormatter(fmt: str, datefmt: str | None = None, use_color=True)

Bases: Formatter

Format a log string with colors.

This implementation taken (with modifications) from https://stackoverflow.com/a/384125.

Methods:

Name Description
format

Attributes:

Name Type Description
BOLD_SEQ
COLORS
COLOR_SEQ
RESET_SEQ
use_color
Source code in molmo_spaces/utils/mp_logging.py
def __init__(self, fmt: str, datefmt: str | None = None, use_color=True) -> None:
    """Create the formatter.

    Args:
        fmt: ``logging`` format string passed to the base ``Formatter``.
        datefmt: Optional date format passed to the base ``Formatter``.
        use_color: If False, ``format`` leaves level names uncolored.
    """
    self.use_color = use_color
    super().__init__(fmt=fmt, datefmt=datefmt)
BOLD_SEQ class-attribute instance-attribute
BOLD_SEQ = '\x1b[1m'
COLORS class-attribute instance-attribute
COLORS = {'WARNING': YELLOW, 'INFO': GREEN, 'DEBUG': BLUE, 'ERROR': RED, 'CRITICAL': MAGENTA}
COLOR_SEQ class-attribute instance-attribute
COLOR_SEQ = '\x1b[1;%dm'
RESET_SEQ class-attribute instance-attribute
RESET_SEQ = '\x1b[0m'
use_color instance-attribute
use_color = use_color
format
format(record: LogRecord) -> str
Source code in molmo_spaces/utils/mp_logging.py
def format(self, record: logging.LogRecord) -> str:
    """Format ``record``, wrapping its level name in ANSI color codes when enabled.

    Coloring applies only if ``self.use_color`` is set and the level name is a key of
    ``self.COLORS``. The record's ``levelname`` is temporarily replaced for formatting and
    always restored afterwards, even if formatting raises, because the same record object
    may be passed to other handlers/formatters.

    Args:
        record: The log record to format.

    Returns:
        The formatted log line.
    """
    levelname = record.levelname
    if not (self.use_color and levelname in self.COLORS):
        return logging.Formatter.format(self, record)

    record.levelname = (
        self.COLOR_SEQ % (30 + self.COLORS[levelname]) + levelname + self.RESET_SEQ
    )
    try:
        return logging.Formatter.format(self, record)
    finally:
        # Resetting levelname as `record` might be used elsewhere
        record.levelname = levelname

ImportChecker

ImportChecker(msg=None)

Methods:

Name Description
__enter__
__exit__

Attributes:

Name Type Description
msg
Source code in molmo_spaces/utils/mp_logging.py
def __init__(self, msg=None) -> None:
    """Remember an optional hint that ``__exit__`` appends to a ``ModuleNotFoundError`` message."""
    self.msg = msg
msg instance-attribute
msg = msg
__enter__
__enter__() -> None
Source code in molmo_spaces/utils/mp_logging.py
def __enter__(self) -> None:
    """Enter the guarded ``with`` block; nothing is set up and ``as`` binds None."""
    return None
__exit__
__exit__(exc_type, value, traceback) -> bool
Source code in molmo_spaces/utils/mp_logging.py
def __exit__(self, exc_type, value, traceback) -> bool:
    """Append ``self.msg`` to a ``ModuleNotFoundError`` raised inside the ``with`` body.

    Args:
        exc_type: Exception class raised in the body, or None.
        value: The exception instance, or None.
        traceback: The traceback, or None (unused).

    Returns:
        True when the body finished without an exception; False otherwise, so every
        exception (possibly with an extended message) keeps propagating.
    """
    if exc_type is ModuleNotFoundError and self.msg is not None:
        # ImportError.msg is None when the error was raised without a message
        # (e.g. `raise ModuleNotFoundError()`); `None += str` would raise a TypeError
        # from inside __exit__ and mask the original error, so treat it as empty.
        value.msg = (value.msg or "") + self.msg
    return exc_type is None

find_free_port

find_free_port(address: str = '127.0.0.1') -> int

Finds a free port for distributed training.

Returns

port: port number that can be used to listen

Source code in molmo_spaces/utils/mp_logging.py
def find_free_port(address: str = "127.0.0.1") -> int:
    """Finds a free port for distributed training.

    An IPv4 TCP socket is bound to ``(address, 0)`` so the OS assigns an unused port.
    That port number is read back, and the socket is closed before returning. Another
    process may still take the port between this call returning and the caller binding it.

    # Returns

    port: port number that can be used to listen
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind((address, 0))
        # NOTE(review): SO_REUSEADDR is set only after bind(), so it cannot influence the
        # port picked above; kept as-is for behavior parity.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.getsockname()[1]

get_logger

get_logger() -> Logger

Get a logging.Logger to stderr. It can be called whenever we wish to log some message. Messages can get mixed-up (https://docs.python.org/3.6/library/multiprocessing.html#logging), but it works well in most cases.

Returns

logger: the logging.Logger object

Source code in molmo_spaces/utils/mp_logging.py
def get_logger() -> logging.Logger:
    """Get the shared `logging.Logger` that writes to stderr.

    Call it whenever a message needs to be logged. Output from several processes can get
    mixed up (https://docs.python.org/3.6/library/multiprocessing.html#logging), but it
    works well in most cases.

    # Returns

    logger: the `logging.Logger` object
    """
    # Presumably _new_logger() returns truthy only when it (re)created the logger;
    # otherwise the existing logger is returned as-is.
    if not _new_logger():
        return _LOGGER

    # In the main process, call _new_logger again with DEBUG (presumably raising verbosity).
    if mp.current_process().name == "MainProcess":
        _new_logger(logging.DEBUG)
    _set_log_formatter()
    return _LOGGER

get_worker_logger

get_worker_logger(worker_id: int) -> Logger

Create a logger specific to a worker that includes the worker ID in all messages

Source code in molmo_spaces/utils/mp_logging.py
def get_worker_logger(worker_id: int) -> logging.Logger:
    """Create a logger specific to a worker that includes the worker ID in all messages.

    The first call for a given ``worker_id`` configures the ``worker_{worker_id}`` logger with:
      - a console handler whose prefix shows the worker ID in one of 8 ANSI colors,
        chosen deterministically from ``worker_id``;
      - if ``_LOG_FILE`` is set, a file handler appending plain (uncolored) lines to it;
      - the effective level of the main logger (``get_logger()``), and ``propagate = False``.

    Side effect: the handlers of the ``molmo_spaces`` package logger are replaced by this
    worker's handlers, so module-level loggers under that package emit through them. If
    several workers are configured in the same process, the last one configured wins.

    Later calls with the same ``worker_id`` return the existing logger unchanged.

    Args:
        worker_id: Identifier included in every message and used to pick the console color.

    Returns:
        The ``logging.Logger`` named ``f"worker_{worker_id}"``.
    """
    # Create a new logger for this worker (logging.getLogger caches loggers by name)
    worker_logger = logging.getLogger(f"worker_{worker_id}")

    # Only add handler if it doesn't already exist
    if not worker_logger.handlers:
        # Assign a consistent color to this worker based on worker_id.
        # Use a hash of worker_id to get a deterministic color
        color_hash = hashlib.md5(str(worker_id).encode()).hexdigest()
        color_index = int(color_hash[:2], 16) % 8  # Use first 2 hex chars to get 0-7
        worker_color_seq = ColoredFormatter.COLOR_SEQ % (30 + color_index)

        # Resolve every color placeholder in the format string up front. Substituting
        # "$WORKER_COLOR" here, rather than in the formatted output, leaves log messages
        # and tracebacks that happen to contain that text untouched.
        default_format = "[$BOLD%(asctime)s $WORKER_COLORWorker %(worker_id)s$RESET$RESET %(levelname)s %(filename)s:%(lineno)d] %(message)s"
        default_format = (
            default_format.replace("$BOLD", ColoredFormatter.BOLD_SEQ)
            .replace("$RESET", ColoredFormatter.RESET_SEQ)
            .replace("$WORKER_COLOR", worker_color_seq)
        )

        class WorkerFormatter(ColoredFormatter):
            """Colored formatter that stamps the worker ID onto each record."""

            def __init__(self, fmt, datefmt, worker_id) -> None:
                super().__init__(fmt, datefmt)
                self.worker_id = worker_id

            def format(self, record):
                # Consumed by the %(worker_id)s field of the format string.
                record.worker_id = self.worker_id
                return super().format(record)

        worker_formatter = WorkerFormatter(
            fmt=default_format,
            datefmt="%m/%d %H:%M:%S",
            worker_id=worker_id,
        )

        # Add console handler
        ch = logging.StreamHandler()
        ch.setFormatter(worker_formatter)
        worker_logger.addHandler(ch)

        # Add file handler if main logger has file logging enabled
        if _LOG_FILE:
            # Use the same log file as the main logger
            # Create file handler with plain formatter (no colors for file)
            class WorkerFileFormatter(logging.Formatter):
                """Plain formatter that stamps the worker ID onto each record."""

                def __init__(self, fmt, datefmt, worker_id) -> None:
                    super().__init__(fmt, datefmt)
                    self.worker_id = worker_id

                def format(self, record):
                    record.worker_id = self.worker_id
                    return super().format(record)

            worker_file_formatter = WorkerFileFormatter(
                fmt="%(asctime)s %(levelname)s: [Worker %(worker_id)s] %(message)s\t[%(filename)s: %(lineno)d]",
                datefmt="%m/%d %H:%M:%S",
                worker_id=worker_id,
            )

            fh = logging.FileHandler(_LOG_FILE)
            fh.setFormatter(worker_file_formatter)
            worker_logger.addHandler(fh)

        # Set the same level as the main logger
        main_logger = get_logger()
        worker_logger.setLevel(main_logger.getEffectiveLevel())

        # Prevent propagation to avoid duplicate logs
        worker_logger.propagate = False

        # Configure the root molmo_spaces logger to use worker logger's handlers
        # This ensures module-level loggers (like log = logging.getLogger(__name__))
        # in molmo_spaces modules work properly in worker processes
        molmo_spaces_logger = logging.getLogger("molmo_spaces")
        molmo_spaces_logger.handlers = worker_logger.handlers.copy()
        molmo_spaces_logger.setLevel(worker_logger.level)
        molmo_spaces_logger.propagate = False  # Don't propagate to root to avoid duplication

    return worker_logger

init_logging

init_logging(human_log_level: str = 'info', log_file: str | None = None) -> None

Init the logging.Logger.

It should be called only once in the app (e.g. in main). It sets the log level to one of HUMAN_LOG_LEVELS and sets up handlers for stderr and, optionally, a log file. The logging level is propagated to all subprocesses.

Parameters:

Name Type Description Default
human_log_level str

Log level as a human-readable string. One of "debug", "info", "warning", "error", "none".

'info'
log_file str | None

Optional path to a log file. If provided, logs will also be written to this file. All worker loggers will also write to the same file with worker ID prefixes.

None
Source code in molmo_spaces/utils/mp_logging.py
def init_logging(human_log_level: str = "info", log_file: str | None = None) -> None:
    """Init the `logging.Logger`.

    Call this once per application (e.g. in `main`). It records the optional log file,
    then (re)configures the logger at the requested level and installs its formatter.
    The logging level is propagated to all subprocesses.

    Args:
        human_log_level: Log level as a human-readable string. One of "debug", "info", "warning", "error", "none".
        log_file: Optional path to a log file. If provided, logs will also be written to this file.
                  All worker loggers will also write to the same file with worker ID prefixes.
    """
    global _LOG_FILE
    # Recorded first so it is set even if the level string turns out to be invalid.
    _LOG_FILE = log_file
    numeric_level = _human_log_level_to_int(human_log_level)
    _new_logger(numeric_level)
    _set_log_formatter()

restore_worker_stdout

restore_worker_stdout() -> None

Restore the previous stdout for the current thread

Source code in molmo_spaces/utils/mp_logging.py
def restore_worker_stdout() -> None:
    """Restore the previous stdout for the current thread.

    If a stdout was saved by ``setup_worker_stdout``, it is reinstated as ``sys.stdout`` and
    the saved entry is dropped; the stored worker stream is dropped too. Missing entries are
    skipped, so calling this without a prior setup does nothing.
    """
    storage = _worker_logger_storage
    if hasattr(storage, "previous_stdout"):
        sys.stdout = storage.previous_stdout
        del storage.previous_stdout
    if hasattr(storage, "worker_stream"):
        del storage.worker_stream

setup_worker_stdout

setup_worker_stdout(worker_logger: Logger, worker_id: int = None) -> None

Set up stdout redirection for a worker thread to use the worker's logger

Source code in molmo_spaces/utils/mp_logging.py
def setup_worker_stdout(worker_logger: logging.Logger, worker_id: int = None) -> None:
    """Set up stdout redirection for a worker thread to use the worker's logger.

    Args:
        worker_logger: Logger that receives everything written to ``sys.stdout``.
        worker_id: Worker ID passed to the stream wrapper. When None and the logger is named
            ``"worker_<id>"``, it is parsed from the name (0 if parsing fails); otherwise it
            stays None.
    """
    # Save the pre-redirection stdout only once, so repeated calls don't overwrite it
    # with an already-redirected stream.
    if not hasattr(_worker_logger_storage, "previous_stdout"):
        _worker_logger_storage.previous_stdout = sys.stdout

    if worker_id is None and worker_logger.name.startswith("worker_"):
        try:
            worker_id = int(worker_logger.name.split("_")[1])
        except (IndexError, ValueError):
            worker_id = 0

    stream = _WorkerStreamToLogger(worker_logger, worker_id)
    _worker_logger_storage.worker_stream = stream
    sys.stdout = cast(io.TextIOWrapper, stream)

update_log_level

update_log_level(logger, human_log_level: str) -> None
Source code in molmo_spaces/utils/mp_logging.py
def update_log_level(logger, human_log_level: str) -> None:
    """Set ``logger``'s level from a human-readable level name (e.g. "info")."""
    numeric_level = _human_log_level_to_int(human_log_level)
    logger.setLevel(numeric_level)

worker_stdout_context

worker_stdout_context(worker_logger: Logger, worker_id: int = None)

Context manager for worker-specific stdout redirection

Source code in molmo_spaces/utils/mp_logging.py
@contextmanager
def worker_stdout_context(worker_logger: logging.Logger, worker_id: int | None = None):
    """Context manager for worker-specific stdout redirection

    NOTE(review): the body only yields. It calls neither ``setup_worker_stdout`` nor
    ``restore_worker_stdout``, so ``sys.stdout`` is NOT redirected and both arguments are
    ignored. Confirm whether this no-op is intentional (e.g. redirection deliberately
    disabled) before relying on it.

    Args:
        worker_logger: Logger that stdout would be routed to (currently unused).
        worker_id: Optional worker ID (currently unused).

    Yields:
        None.
    """
    yield

mujoco_scene_utils

Functions:

Name Description
add_visual_capsule

Adds one capsule to an mjvScene.

get_supporting_geom

Finds the supporting geometry for an object, using a heuristic.

is_object_supported_by_body

Checks if an object is supported by a given body, using heuristics.

place_object_near

Place an object near a point such that the bottom of the object (i.e. the base) is at the specified z-value, with a random yaw.

randomize_door_joints

Modify door and handle joint parameters in a house spec.

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

add_visual_capsule

add_visual_capsule(scene, point1, point2, radius, rgba) -> None

Adds one capsule to an mjvScene. These geometries are automatically visual-only and don't participate in collision detection.

Source code in molmo_spaces/utils/mujoco_scene_utils.py
def add_visual_capsule(scene, point1, point2, radius, rgba) -> None:
    """Append one capsule spanning ``point1``-``point2`` to an mjvScene.

    Geoms added to an mjvScene are visual-only and don't participate in collision
    detection. If the scene's geom buffer is already full (``ngeom >= maxgeom``), the
    call does nothing.

    Args:
        scene: The ``mjvScene`` to append to.
        point1: First capsule endpoint.
        point2: Second capsule endpoint.
        radius: Capsule radius.
        rgba: Color; must support ``.astype`` (e.g. a numpy array).
    """
    slot = scene.ngeom
    if slot >= scene.maxgeom:
        return
    scene.ngeom = slot + 1
    capsule = scene.geoms[slot]
    # Initialise with zero pos/size/mat placeholders, then let mjv_connector place the
    # capsule between the two endpoints.
    mujoco.mjv_initGeom(
        capsule,
        mujoco.mjtGeom.mjGEOM_CAPSULE,
        np.zeros(3),
        np.zeros(3),
        np.zeros(9),
        rgba.astype(np.float32),
    )
    mujoco.mjv_connector(capsule, mujoco.mjtGeom.mjGEOM_CAPSULE, radius, point1, point2)

get_supporting_geom

get_supporting_geom(data: MjData, object_id: int, angle_threshold: float = radians(80)) -> int | None

Finds the supporting geometry for an object, using a heuristic. Searches for a geom in contact with the object, such that the contact is in the bottom half of the object's AABB and the normal is pointing upwards.

Parameters:

Name Type Description Default
data MjData

MjData object

required
object_id int

Body ID of the root body to find the supporting geometry for

required
angle_threshold float

Threshold for the angle between the normal and the vertical axis to be considered parallel, in radians

radians(80)

Returns:

Name Type Description
int int | None

Geom ID of the supporting geometry, or None if no supporting geometry is found

Source code in molmo_spaces/utils/mujoco_scene_utils.py
def get_supporting_geom(
    data: MjData, object_id: int, angle_threshold: float = np.radians(80)
) -> int | None:
    """
    Finds the supporting geometry for an object, using a heuristic.

    A geom qualifies when all of the following hold:
      - it is in contact with the object and belongs to a different root body;
      - the contact point lies below the center of the object's AABB (its bottom half);
      - the contact normal, oriented towards the object, is within ``angle_threshold``
        of the +z axis.

    Args:
        data: MjData object (its ``model`` attribute provides the MjModel)
        object_id: Body ID of the root body to find the supporting geometry for
        angle_threshold: Maximum angle between the contact normal and the vertical axis, in radians

    Returns:
        int: Geom ID of the first qualifying geometry (in contact order), or None if no supporting geometry is found
    """
    model = data.model
    assert model.body_rootid[object_id] == object_id, "Object is not a root body"

    try:
        body_aabb_center, _ = body_aabb(model, data, object_id, visible_only=True)
    except ValueError:
        # fallback if body doesn't have any visible geoms (usually not the case)
        body_aabb_center, _ = body_aabb(model, data, object_id, visible_only=False)
    min_cos = np.cos(angle_threshold)

    for contact in data.contact:
        geom_pair = contact.geom
        root_a, root_b = model.body_rootid[model.geom_bodyid[geom_pair]]
        first_is_object = root_a == object_id
        second_is_object = root_b == object_id
        # Keep only contacts with exactly one side belonging to the object.
        if first_is_object == second_is_object:
            continue

        other_geom = geom_pair[1] if first_is_object else geom_pair[0]
        direction = contact.frame[:3]
        normal = direction / np.linalg.norm(direction)
        # Flip so the normal points towards the object when the object owns the first geom.
        if first_is_object:
            normal = -normal

        below_center = contact.pos[2] < body_aabb_center[2]
        if below_center and normal[2] >= min_cos:
            return other_geom
    return None

is_object_supported_by_body

is_object_supported_by_body(data: MjData, object_id: int, support_id: int, angle_threshold: float = radians(30), frac_weight_threshold: float = 0.5, eps: float = 1e-06) -> bool

Checks if an object is supported by a given body, using heuristics. This is more precise than get_supporting_geom.

Parameters:

Name Type Description Default
data MjData

MjData object

required
object_id int

Body ID of the root body to check if it is supported by the supporting body

required
support_id int

Body ID of the supporting body to check if it is supporting the object

required
angle_threshold float

Threshold for the angle between the normal and the vertical axis to be considered parallel, in radians

radians(30)
frac_weight_threshold float

The upward component of the contact force must be at least this fraction of the object weight to be considered supported

0.5
eps float

Threshold for the net contact force to be considered non-zero

1e-06

Returns:

Name Type Description
bool bool

True if the object is supported by the given support, False otherwise

Source code in molmo_spaces/utils/mujoco_scene_utils.py
def is_object_supported_by_body(
    data: MjData,
    object_id: int,
    support_id: int,
    angle_threshold: float = np.radians(30),
    frac_weight_threshold: float = 0.5,
    eps: float = 1e-6,
) -> bool:
    """
    Checks if an object is supported by a given body, using heuristics.
    This is more precise than get_supporting_geom.

    All contacts between the object's tree and the tree of the support's *root* body are
    summed into one world-frame force acting on the object. The object counts as supported
    when all of these hold:
      - that force is non-negligible (norm >= ``eps``);
      - it points within ``angle_threshold`` of +z;
      - its vertical component is at least ``frac_weight_threshold`` times the object's
        weight (subtree mass times the gravity magnitude).

    Note: matching uses ``body_rootid[support_id]``, so contacts with any body sharing the
    support's root count.

    Args:
        data: MjData object
        object_id: Body ID of the root body to check if it is supported by the supporting body
        support_id: Body ID of the supporting body to check if it is supporting the object
        angle_threshold: Threshold for the angle between the normal and the vertical axis to be considered parallel, in radians
        frac_weight_threshold: The upward component of the contact force must be at least this fraction of the object weight to be considered supported
        eps: Threshold for the net contact force to be considered non-zero

    Returns:
        bool: True if the object is supported by the given support, False otherwise
    """
    model = data.model
    assert model.body_rootid[object_id] == object_id, "Object is not a root body"
    support_root = model.body_rootid[support_id]
    target_pair = {support_root, object_id}

    net_force = np.zeros(3)
    for cid in range(data.ncon):
        contact = data.contact[cid]
        root_a, root_b = model.body_rootid[model.geom_bodyid[contact.geom]]
        # only check contacts between the object and the body
        if {root_a, root_b} != target_pair:
            continue

        wrench = np.zeros(6)
        mujoco.mj_contactForce(model, data, cid, wrench)
        # Negate when the object owns the first geom, so the force acts on the object.
        if root_a == object_id:
            wrench = -wrench
        # contact.frame holds the contact axes as rows; its transpose maps contact-frame
        # vectors into world coordinates.
        frame_to_world = contact.frame.reshape(3, 3).T
        net_force += frame_to_world @ wrench[:3]

    force_magnitude = np.linalg.norm(net_force)
    if force_magnitude < eps:
        # no contact between objects
        return False

    cos_to_z = net_force[2] / force_magnitude.item()
    contact_is_vertical = cos_to_z >= np.cos(angle_threshold)

    gravity_magnitude = np.linalg.norm(model.opt.gravity)
    object_weight: float = (model.body_subtreemass[object_id] * gravity_magnitude).item()
    is_supporting_weight = np.abs(net_force[2]).item() >= frac_weight_threshold * object_weight

    # TODO: this fails to capture transitive support, e.g. if the object is on another body which is on the support.
    # This is an unlikely edge case, so we'll leave it for now.
    # Future solutions could consider building a contact graph between the object and the support's support (e.g. the table)
    # and checking that a sufficient amount of weight is bottlenecked through the support node. This would handle
    # both transitive support and multiple supports.

    return contact_is_vertical and is_supporting_weight

place_object_near

place_object_near(data: MjData, object_id: int, placement_point: ndarray, min_dist: float, max_dist: float, max_tries: int = 100, reference_pos: ndarray | None = None, max_dist_to_reference: float = 1.0, supporting_geom_id: int | None = None, z_eps: float = 0.001)

Place an object near a point such that the bottom of the object (i.e. the base) is at the specified z-value, with a random yaw. Optionally, ensure the placed object is within a certain distance of a reference position.

Parameters:

Name Type Description Default
data MjData

MjData object

required
object_id int

ID of the object to place

required
placement_point ndarray

Point to place the object near

required
min_dist float

Minimum distance from the placement point

required
max_dist float

Maximum distance from the placement point

required
max_tries int

Maximum number of placement attempts

100
reference_pos ndarray | None

Reference position to place the object near

None
max_dist_to_reference float

Maximum distance to the reference position

1.0
supporting_geom_id int | None

ID of the supporting geometry to optionally ensure the object is placed on top of

None
z_eps float

Epsilon to add to the z-offset to avoid collision

0.001

Raises:

Type Description
ObjectPlacementError

If the object cannot be placed within the specified number of attempts

Source code in molmo_spaces/utils/mujoco_scene_utils.py
def place_object_near(
    data: MjData,
    object_id: int,
    placement_point: np.ndarray,
    min_dist: float,
    max_dist: float,
    max_tries: int = 100,
    reference_pos: np.ndarray | None = None,
    max_dist_to_reference: float = 1.0,
    supporting_geom_id: int | None = None,
    z_eps: float = 1e-3,
):
    """
    Place an object near a point such that the bottom of the object (i.e. the base) is at the specified z-value, with a random yaw.
    Optionally, ensure the placed object is within a certain distance of a reference position.

    Candidate xy positions are drawn with a uniform azimuth and a uniform radius in
    [min_dist, max_dist] around placement_point. They are filtered to the supporting geometry's
    (shrunk) xy footprint and to the reference radius when those are given. The object's base is
    placed at placement_point's z, or at the top of the supporting geometry's AABB when
    supporting_geom_id is given, plus z_eps. Each candidate is tried with a fresh random yaw
    until one produces no contact between the object and any root body other than the support.

    On success the object is left at the accepted pose (only mj_fwdPosition has been run). On
    failure the original pose is restored and mj_forward is called before raising.

    Args:
        data: MjData object
        object_id: ID of the object to place. NOTE(review): it is compared against root body
            ids in the contact check below, so it presumably must be a root body id — confirm.
        placement_point: Point to place the object near
        min_dist: Minimum distance from the placement point
        max_dist: Maximum distance from the placement point
        max_tries: Maximum number of placement attempts. Also caps the number of 1024-sample
            rounds used to gather candidate positions.
        reference_pos: Reference position to place the object near (only its xy is used)
        max_dist_to_reference: Maximum xy distance to the reference position
        supporting_geom_id: ID of the supporting geometry to optionally ensure the object is placed on top of
        z_eps: Epsilon to add to the z-offset to avoid collision

    Raises:
        ObjectPlacementError: If the object cannot be placed within the specified number of attempts
    """
    object_body = create_mlspaces_body(data, object_id)
    original_pose = object_body.pose

    # Height of the body origin above the bottom of its AABB: adding this to a target z puts
    # the object's base at that z (a yaw about world z does not change it).
    body_aabb_center, body_aabb_size = body_aabb(data.model, data, object_id)
    z_offset = object_body.position[2] - (body_aabb_center[2] - body_aabb_size[2] / 2)

    if supporting_geom_id is not None:
        support_geom_aabb_center, support_geom_aabb_size = geom_aabb(
            data.model, data, [supporting_geom_id]
        )
        # Allowed xy box: the support's AABB footprint, shrunk on each side by a quarter of the
        # object's xy extent (measured at the object's original orientation).
        placement_pos_min = (
            support_geom_aabb_center[:2] - support_geom_aabb_size[:2] / 2 + body_aabb_size[:2] / 4
        )
        placement_pos_max = (
            support_geom_aabb_center[:2] + support_geom_aabb_size[:2] / 2 - body_aabb_size[:2] / 4
        )
        # Contacts with the support's root body are expected and are not treated as collisions.
        supporting_root_body_id = data.model.body_rootid[data.model.geom_bodyid[supporting_geom_id]]
        # Copy before overriding z so the caller's array is not mutated; the target z becomes
        # the top face of the support's AABB.
        placement_point = placement_point.copy()
        placement_point[2] = support_geom_aabb_center[2] + support_geom_aabb_size[2] / 2
    else:
        placement_pos_min = np.full(2, -np.inf)
        placement_pos_max = np.full(2, np.inf)
        supporting_root_body_id = None

    # first generate the candidate placement positions that satisfy the distance constraints
    candidate_placement_pos_xy = np.zeros((max_tries, 2))
    n_candidates = 0
    i = 0
    while n_candidates < max_tries:
        # Draw a batch of N points on the annulus [min_dist, max_dist] around placement_point.
        N = 1024
        azimuth = np.random.uniform(-np.pi, np.pi, N)
        distance = np.random.uniform(min_dist, max_dist, N)
        xy_offset = distance.reshape(-1, 1) * np.stack([np.cos(azimuth), np.sin(azimuth)], axis=1)
        placement_pos_xy = placement_point[:2][None] + xy_offset

        # Keep samples inside the allowed xy box (and within reach of reference_pos, if given).
        eligible_mask = np.all(
            (placement_pos_min <= placement_pos_xy) & (placement_pos_xy <= placement_pos_max),
            axis=1,
        )
        if reference_pos is not None:
            dist_to_reference = np.linalg.norm(placement_pos_xy - reference_pos[:2][None], axis=1)
            eligible_mask &= dist_to_reference <= max_dist_to_reference

        # Append as many eligible samples as still fit into the candidate buffer.
        new_n_candidates = min(max_tries, n_candidates + eligible_mask.sum())
        candidate_placement_pos_xy[n_candidates:new_n_candidates] = placement_pos_xy[eligible_mask][
            : new_n_candidates - n_candidates
        ]
        n_candidates = new_n_candidates

        i += 1
        if i >= max_tries:
            # NOTE(review): this branch also runs (and logs) when the final round happened to
            # fill the buffer; the slice below is then a no-op.
            log.debug(
                f"Failed to sample {max_tries} candidate placement positions within {max_tries} attempts"
            )
            candidate_placement_pos_xy = candidate_placement_pos_xy[:n_candidates]
            break

    # for each candidate placement position, try to place the object and check for collisions
    for attempt, placement_pos_xy in enumerate(candidate_placement_pos_xy):
        yaw = np.random.uniform(-np.pi, np.pi)
        placement_pos = np.array(
            [placement_pos_xy[0], placement_pos_xy[1], placement_point[2] + z_offset + z_eps]
        )
        placement_pose = np.eye(4)
        placement_pose[:3, 3] = placement_pos
        # Apply the random yaw about the world z-axis on top of the original orientation.
        placement_pose[:3, :3] = R.from_euler("z", yaw).as_matrix() @ original_pose[:3, :3]
        object_body.pose = placement_pose

        # mj_fwdPosition runs the position-dependent stages (including collision detection),
        # which refreshes data.contact for the check below.
        mujoco.mj_fwdPosition(data.model, data)

        in_collision = False
        # TODO(abhayd): do we need to place on the same surface? Why not just any surface?
        for c in data.contact:
            root_body1 = data.model.body_rootid[data.model.geom_bodyid[c.geom1]]
            root_body2 = data.model.body_rootid[data.model.geom_bodyid[c.geom2]]
            # Only contacts with exactly one side on the object matter; such a contact with any
            # root body other than the support counts as a collision.
            if (root_body1 == object_id) ^ (root_body2 == object_id):
                other_root_body = root_body1 if root_body1 != object_id else root_body2
                if other_root_body != supporting_root_body_id:
                    in_collision = True
                    break

        if not in_collision:
            log.debug(
                f"Successfully placed object with ID {object_id} after {attempt + 1} attempts"
            )
            break

    else:
        # Reached only when no candidate was collision-free (or none could be sampled):
        # restore the original pose and fully recompute derived quantities before raising.
        object_body.pose = original_pose
        mujoco.mj_forward(data.model, data)
        raise ObjectPlacementError(
            f"Failed to place object with ID {object_id} within {max_tries} attempts"
        )

randomize_door_joints

randomize_door_joints(spec: MjSpec, scene_metadata: dict, door_stiffness_range: tuple = (3, 7), door_damping_range: tuple = (8, 12), door_frictionloss_range: tuple = (8, 12), handle_stiffness_range: tuple = (200, 300), handle_damping_range: tuple = (80, 120), handle_frictionloss_range: tuple = (40, 60), add_handle_limits: bool = True) -> None

Modify door and handle joint parameters in a house spec.

This function identifies door joints and handle joints by their naming patterns and modifies their physical parameters (stiffness, damping, frictionloss) with randomized values within specified ranges.

It also sets the ref and springref based on range heuristics.

Parameters:

Name Type Description Default
spec MjSpec

The model spec

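required
scene_metadata dict

Scene metadata; door and handle joint names are collected from the name_map["joints"] of every "doorway" entry in scene_metadata["objects"]

required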
door_stiffness_range tuple

(min, max) range for door joint stiffness (default: reduce from ~250 to 3-7)

(3, 7)
door_damping_range tuple

(min, max) range for door joint damping (default: reduce from ~100 to 8-12)

(8, 12)
door_frictionloss_range tuple

(min, max) range for door joint frictionloss (default: reduce from ~50 to 8-12)

(8, 12)
handle_stiffness_range tuple

(min, max) range for handle joint stiffness (default: increase from ~0 to 200-300)

(200, 300)
handle_damping_range tuple

(min, max) range for handle joint damping (default: increase from ~0.1 to 80-120)

(80, 120)
handle_frictionloss_range tuple

(min, max) range for handle joint frictionloss (default: increase from ~0 to 40-60)

(40, 60)
add_handle_limits bool

Whether to add limited="true" and ref/springref attributes to handle joints

True
Source code in molmo_spaces/utils/mujoco_scene_utils.py
def randomize_door_joints(  # TODO: do these defaults make sense?
    spec: MjSpec,
    scene_metadata: dict,
    door_stiffness_range: tuple = (3, 7),
    door_damping_range: tuple = (8, 12),
    door_frictionloss_range: tuple = (8, 12),
    handle_stiffness_range: tuple = (200, 300),
    handle_damping_range: tuple = (80, 120),
    handle_frictionloss_range: tuple = (40, 60),
    add_handle_limits: bool = True,
) -> None:
    """
    Modify door and handle joint parameters in a house spec.

    This function identifies door joints and handle joints by their naming patterns and
    modifies their physical parameters (stiffness, damping, frictionloss) with randomized
    values within specified ranges.

    If a joint had zero stiffness beforehand, its springref is set to the lower end of its
    range. Door joints are always marked limited and get armature 1.0 if it was 0. Handle
    joints always get armature 1.0 and are marked limited only when add_handle_limits is set.

    Args:
        spec: The model spec; matching joints are modified in place.
        scene_metadata: Scene metadata. Every entry of ``scene_metadata["objects"]`` whose key
            contains "doorway" contributes the joint names in its ``name_map["joints"]``;
            names containing "handle" are handle joints, all others are door joints.
        door_stiffness_range: (min, max) range for door joint stiffness (default: reduce from ~250 to 3-7)
        door_damping_range: (min, max) range for door joint damping (default: reduce from ~100 to 8-12)
        door_frictionloss_range: (min, max) range for door joint frictionloss (default: reduce from ~50 to 8-12)
        handle_stiffness_range: (min, max) range for handle joint stiffness (default: increase from ~0 to 200-300)
        handle_damping_range: (min, max) range for handle joint damping (default: increase from ~0.1 to 80-120)
        handle_frictionloss_range: (min, max) range for handle joint frictionloss (default: increase from ~0 to 40-60)
        add_handle_limits: Whether to mark handle joints as limited and, if they had no
            stiffness, set their springref to the lower end of their range.
    """
    log.debug("Starting joint modifications")
    scene_objects = scene_metadata.get("objects", {})
    log.debug(f"Found {len(scene_objects)} scene objects to check")
    handle_joints = []
    door_joints = []
    for key, value in scene_objects.items():
        if "doorway" not in key:
            continue
        joints = value.get("name_map", {}).get("joints", {})
        # Iteration yields the joint names (the keys, when ``joints`` is a dict).
        for joint in joints:
            if "handle" in joint:
                handle_joints.append(joint)
            else:
                door_joints.append(joint)

    log.debug(f"Found {len(door_joints)} door joints: {door_joints}")
    log.debug(f"Found {len(handle_joints)} handle joints: {handle_joints}")

    # Sets give O(1) membership tests in the loop over every spec joint below; the lists above
    # are kept for the ordered debug output.
    door_joint_names = set(door_joints)
    handle_joint_names = set(handle_joints)

    modifications_count = {"doors": 0, "handles": 0}

    for joint in spec.joints:
        joint: mujoco.MjsJoint

        name = joint.name or ""
        is_door_joint = name in door_joint_names
        is_handle_joint = name in handle_joint_names

        if not (is_door_joint or is_handle_joint):
            continue

        if is_door_joint:
            log.debug(f"[DOOR JOINT MOD] Modifying door joint: {name}")

            # only adjust the springref if the joint is not already set to be a spring
            if joint.stiffness == 0.0:
                joint.springref = joint.range[0].item()

            old_stiffness = joint.stiffness
            old_damping = joint.damping
            old_frictionloss = joint.frictionloss
            joint.stiffness = np.random.uniform(*door_stiffness_range)
            joint.damping = np.random.uniform(*door_damping_range)
            joint.frictionloss = np.random.uniform(*door_frictionloss_range)
            joint.limited = 1
            if joint.armature == 0.0:
                joint.armature = 1.0

            modifications_count["doors"] += 1

            log.debug(f"[DOOR JOINT MOD] stiffness: {old_stiffness} -> {joint.stiffness}")
            log.debug(f"[DOOR JOINT MOD] damping: {old_damping} -> {joint.damping}")
            log.debug(f"[DOOR JOINT MOD] frictionloss: {old_frictionloss} -> {joint.frictionloss}")
        elif is_handle_joint:
            log.debug(f"Modifying handle joint: {name}")

            if add_handle_limits:
                joint.limited = 1
                if joint.stiffness == 0.0:
                    joint.springref = joint.range[0].item()

            old_stiffness = joint.stiffness
            old_damping = joint.damping
            old_frictionloss = joint.frictionloss
            joint.stiffness = np.random.uniform(*handle_stiffness_range)
            joint.damping = np.random.uniform(*handle_damping_range)
            joint.frictionloss = np.random.uniform(*handle_frictionloss_range)
            joint.armature = 1.0  # TODO(abhay): why set armature to 1.0?
            modifications_count["handles"] += 1

            log.debug(f"[HANDLE JOINT MOD] stiffness: {old_stiffness} -> {joint.stiffness}")
            log.debug(f"[HANDLE JOINT MOD] damping: {old_damping} -> {joint.damping}")
            log.debug(
                f"[HANDLE JOINT MOD] frictionloss: {old_frictionloss} -> {joint.frictionloss}"
            )

    log.debug(
        f"Completed joint modifications: {modifications_count['doors']} door joints, {modifications_count['handles']} handle joints"
    )

object_metadata

Classes:

Name Description
DictUnion

Union of multiple nonoverlapping dictionaries.

ObjectMeta
UserLibraryMetadata

Class which provides dict-like access to a user library metadata.

Functions:

Name Description
clip_sim
compute_text_clip
get_annotation
get_clip_model
get_db
get_metadata_lmdb_dir

Attributes:

Name Type Description
DEFAULT_CLIP_MODEL
DEFAULT_CLIP_PRETRAIN
DEFAULT_DEVICE
all_descriptions
all_descs
all_sims
asset_id
asset_ids
description
descriptions
img
text_clips
texts

DEFAULT_CLIP_MODEL module-attribute

DEFAULT_CLIP_MODEL = 'ViT-L-14'

DEFAULT_CLIP_PRETRAIN module-attribute

DEFAULT_CLIP_PRETRAIN = 'laion2b_s32b_b82k'

DEFAULT_DEVICE module-attribute

DEFAULT_DEVICE = 'cpu'

all_descriptions module-attribute

all_descriptions = all_descriptions(asset_ids)

all_descs module-attribute

all_sims module-attribute

all_sims = squeeze()

asset_id module-attribute

asset_id = 'Alarm_Clock_1'

asset_ids module-attribute

asset_ids = ['Alarm_Clock_1', 'Wall_Decor_Photo_9']

description module-attribute

descriptions module-attribute

descriptions = list((annotation(asset_id)['description']) for asset_id in asset_ids)

img module-attribute

text_clips module-attribute

text_clips = compute_text_clip(texts)

texts module-attribute

texts = ['a tree', 'a watch', 'a tower clock', 'a clock', 'an alarm clock', 'a blue alarm clock']

DictUnion

DictUnion(*dicts, raise_on_missing: bool = False)

Bases: Mapping

Union of multiple nonoverlapping dictionaries. This will not check for key collisions between dictionaries!

Parameters:

Name Type Description Default
*dicts

The dictionaries to union.

()
raise_on_missing bool

Whether to raise an error if a key is not found in any of the dictionaries.

False

Methods:

Name Description
__contains__
__getitem__
__iter__
__len__
get
Source code in molmo_spaces/utils/object_metadata.py
def __init__(self, *dicts, raise_on_missing: bool = False):
    self._dicts = list(dicts)
    self._raise_on_missing = raise_on_missing
__contains__
__contains__(key)
Source code in molmo_spaces/utils/object_metadata.py
def __contains__(self, key):
    return any(key in d for d in self._dicts)
__getitem__
__getitem__(key)
Source code in molmo_spaces/utils/object_metadata.py
def __getitem__(self, key):
    for d in self._dicts:
        if key in d:
            return d[key]
    if self._raise_on_missing:
        raise KeyError(f"Key {key} not found in any of the dictionaries")
    return None
__iter__
__iter__()
Source code in molmo_spaces/utils/object_metadata.py
def __iter__(self):
    return chain.from_iterable(self._dicts)
__len__
__len__()
Source code in molmo_spaces/utils/object_metadata.py
def __len__(self):
    return sum(len(d) for d in self._dicts)
get
get(key, default=None)
Source code in molmo_spaces/utils/object_metadata.py
def get(self, key, default=None):
    # Unlike __getitem__, a miss always returns ``default``, regardless of raise_on_missing.
    for mapping in self._dicts:
        if key not in mapping:
            continue
        return mapping[key]
    return default

ObjectMeta

Methods:

Name Description
all_descriptions
all_uids
annotation
clean_object_name
description_text_features
get_features
get_short_description
get_target_object_uid
img_features
short_descriptions
all_descriptions classmethod
all_descriptions(asset_ids: str | list[str])
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def all_descriptions(cls, asset_ids: str | list[str]):
    def func(asset_id):
        anno = cls.annotation(asset_id)
        return cls.short_descriptions(asset_id) + [
            anno["description"],
            anno["description_long"],
        ]

    if isinstance(asset_ids, str):
        return func(asset_ids)

    return [func(asset_id) for asset_id in asset_ids]
all_uids staticmethod
all_uids() -> list[str]
Source code in molmo_spaces/utils/object_metadata.py
@staticmethod
def all_uids() -> list[str]:
    # Every uid known to the metadata database returned by get_db().
    return [uid for uid in get_db().keys()]
annotation staticmethod
annotation(asset_ids: str | list[str] | None = None) -> list[dict | None] | dict | None
Source code in molmo_spaces/utils/object_metadata.py
@staticmethod
def annotation(asset_ids: str | list[str] | None = None) -> list[dict | None] | dict | None:
    """Look up annotation entries in the metadata database.

    Args:
        asset_ids: None for the whole database mapping, a single uid (None if absent), or a
            list of uids (indexed directly, so a miss follows the database's ``[]`` semantics).
    """
    db = get_db()
    if asset_ids is None:
        return db
    if not isinstance(asset_ids, str):
        return [db[uid] for uid in asset_ids]
    return db.get(asset_ids, None)
clean_object_name classmethod
clean_object_name(task: BaseMujocoTask) -> str
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def clean_object_name(cls, task: "BaseMujocoTask") -> str:
    """Human-readable name of the task's pickup object.

    A provided custom-object name (for "custom_object/..." pickups) wins, capitalized;
    otherwise the first short description of the target asset is used.
    """
    # TODO ported from PromptSampler late at night, might need clean-up
    obj_name = task.config.task_config.pickup_obj_name

    if obj_name.startswith("custom_object/"):
        runtime_params = task.config.eval_runtime_params
        custom_name = runtime_params.custom_object_name if runtime_params else None
        if custom_name:
            return custom_name.capitalize()

    target_uid = cls.get_target_object_uid(task)
    return cls.get_short_description(target_uid)[0]
description_text_features classmethod
description_text_features(asset_ids: str | list[str] | None = None)
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def description_text_features(cls, asset_ids: str | list[str] | None = None):
    return cls.get_features("clip_text_features", asset_ids)
get_features classmethod
get_features(feature_type_str: str, asset_ids: str | list[str] | None = None)
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def get_features(cls, feature_type_str: str, asset_ids: str | list[str] | None = None):
    """Fetch one feature array per asset from the metadata database.

    Args:
        feature_type_str: Key of the feature inside each metadata entry
            (e.g. "clip_img_features").
        asset_ids: A single uid, a list of uids, or None for every database entry.

    Returns:
        The features stacked along a new leading axis (of length 1 for a single uid).

    Raises:
        KeyError: If any requested uid is missing from the database.
    """
    db = get_db()

    if asset_ids is None:
        every_feature = [entry[feature_type_str] for entry in db.values()]
        return np.stack(every_feature, axis=0)

    if not isinstance(asset_ids, str):
        if not all(uid in db for uid in asset_ids):
            raise KeyError(f"Missing some of {len(asset_ids)} asset_ids for {feature_type_str}")
        return np.stack([db[uid][feature_type_str] for uid in asset_ids], axis=0)

    if asset_ids not in db:
        raise KeyError(f"Missing {asset_ids} for {feature_type_str}")
    return np.array([db[asset_ids][feature_type_str]])
get_short_description classmethod
get_short_description(object_uid)
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def get_short_description(cls, object_uid):
    # TODO ported from PromptSampler late at night, might need clean-up
    candidates = cls.short_descriptions(object_uid)
    if candidates:
        return candidates
    # No short descriptions: fall back to the uid prefix before the first "_", four times.
    fallback = object_uid.split("_")[0]
    return [fallback] * 4
get_target_object_uid staticmethod
get_target_object_uid(task) -> str
Source code in molmo_spaces/utils/object_metadata.py
@staticmethod
def get_target_object_uid(task) -> str:
    # TODO ported from PromptSampler late at night, might need clean-up
    # The pickup object's scene-metadata entry records which asset it was instantiated from.
    objects_by_name = task.env.current_scene_metadata["objects"]
    pickup_name = task.config.task_config.pickup_obj_name
    return objects_by_name[pickup_name]["asset_id"]
img_features classmethod
img_features(asset_ids: str | list[str] | None = None)
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def img_features(cls, asset_ids: str | list[str] | None = None):
    return cls.get_features("clip_img_features", asset_ids)
short_descriptions classmethod
short_descriptions(asset_ids: str | list[str])
Source code in molmo_spaces/utils/object_metadata.py
@classmethod
def short_descriptions(cls, asset_ids: str | list[str]):
    def func(asset_id):
        return list(cls.annotation(asset_id)["description_short"].values())

    if isinstance(asset_ids, str):
        if asset_ids in cls.annotation():
            return func(asset_ids)
        return []

    return [func(asset_id) for asset_id in asset_ids]

UserLibraryMetadata

UserLibraryMetadata(user_library_path: Path, user_library_index: dict[str, UserAssetLibraryIndexEntry], lru_cache_size: int = 1000)

Bases: Mapping

Class which provides dict-like access to a user library metadata.

Methods:

Name Description
__contains__
__getitem__
__iter__
__len__
Source code in molmo_spaces/utils/object_metadata.py
def __init__(
    self,
    user_library_path: Path,
    user_library_index: dict[str, UserAssetLibraryIndexEntry],
    lru_cache_size: int = 1000,
):
    """Provide lazy, cached, dict-like access to a user asset library's metadata.

    Args:
        user_library_path: Root directory of the user library; the index entries' paths are
            resolved relative to it.
        user_library_index: Maps uid -> index entry. Each entry provides ``metadata_path``
            (a JSON file) and ``metadata_npz_path`` (an optional ``.npz`` file, or None).
        lru_cache_size: Maximum number of loaded metadata dicts kept in this instance's cache.
    """
    self._user_library_path = user_library_path
    self._user_library_index = user_library_index

    # A closure decorated with lru_cache gives every instance its own bounded cache.
    # NOTE(review): repeated lookups of a uid return the same dict object, so callers that
    # mutate it also change what later lookups see.
    @lru_cache(maxsize=lru_cache_size)
    def _get(key):
        metadata_entry = self._user_library_index[key]
        metadata_rel_path = metadata_entry.metadata_path
        with open(self._user_library_path / metadata_rel_path, "r", encoding="utf-8") as f:
            metadata: dict = json.load(f)

        if metadata_entry.metadata_npz_path is not None:
            # Each npz key is a "/"-separated path into the nested JSON dict, and the array is
            # grafted in at that path. The NpzFile is used as a context manager so its file
            # handle is closed once all arrays have been read into memory.
            with np.load(self._user_library_path / metadata_entry.metadata_npz_path) as metadata_npz:
                for npz_key, npz_value in metadata_npz.items():
                    d = metadata
                    key_parts = npz_key.split("/")
                    for k_part in key_parts[:-1]:
                        d = d.setdefault(k_part, {})
                    k = key_parts[-1]
                    assert k not in d, (
                        f"Key {npz_key} from npz metadata already exists in metadata for uid={key}"
                    )
                    d[k] = npz_value

        return metadata

    self._get = _get
__contains__
__contains__(key)
Source code in molmo_spaces/utils/object_metadata.py
def __contains__(self, key):
    return key in self._user_library_index
__getitem__
__getitem__(key)
Source code in molmo_spaces/utils/object_metadata.py
def __getitem__(self, key):
    return self._get(key)
__iter__
__iter__()
Source code in molmo_spaces/utils/object_metadata.py
def __iter__(self):
    return iter(self._user_library_index)
__len__
__len__()
Source code in molmo_spaces/utils/object_metadata.py
def __len__(self):
    return len(self._user_library_index)

clip_sim

clip_sim(clip_img, clip_text, normalize=True, num_views=3)
Source code in molmo_spaces/utils/object_metadata.py
def clip_sim(clip_img, clip_text, normalize=True, num_views=3):
    """Best-view similarity between per-object image features and text features.

    Args:
        clip_img: Image features shaped (objects, num_views, dim), or (num_views, dim) for a
            single object.
        clip_text: Text features shaped (queries, dim), or (dim,) for a single query.
        normalize: If True, L2-normalize both inputs along their last axis first.
        num_views: Expected number of image views per object (checked against
            ``clip_img.shape[-2]``).

    Returns:
        Array of shape (objects, queries); each entry is the maximum, over views, of the dot
        product between that view's image feature and the query's text feature.
    """
    assert clip_img.shape[-2] == num_views, f"expected {num_views} feature vectors for img modality"

    if normalize:
        clip_text = clip_text / np.linalg.norm(clip_text, axis=-1, keepdims=True)
        clip_img = clip_img / np.linalg.norm(clip_img, axis=-1, keepdims=True)

    # Promote a single object / single query to batched form.
    clip_img = clip_img[None, ...] if clip_img.ndim == 2 else clip_img
    clip_text = clip_text[None, :] if clip_text.ndim == 1 else clip_text

    assert clip_img.ndim == 3 and clip_text.ndim == 2, (
        f"Received clip_img shape {clip_img.shape} and clip_text shape {clip_text.shape}"
    )

    # Reduce view by view so only the running maximum and one (objects, queries) score matrix
    # are alive at a time, rather than scores for all views at once.
    best = np.einsum("od,qd->oq", clip_img[:, 0, :], clip_text)
    for view in range(1, num_views):
        view_scores = np.einsum("od,qd->oq", clip_img[:, view, :], clip_text)
        best = np.maximum(best, view_scores)
    return best

compute_text_clip

compute_text_clip(text_list: str | list[str] | list[list[str]])
Source code in molmo_spaces/utils/object_metadata.py
def compute_text_clip(text_list: str | list[str] | list[list[str]]):
    """Encode text(s) with the shared CLIP text encoder.

    Args:
        text_list: A single string, a list of strings, or a list of lists of strings. Nested
            lists are flattened, in order, into one list before encoding.

    Returns:
        The encoder output moved to CPU as a float16 numpy array (presumably one embedding row
        per flattened input string).
    """
    import torch

    if isinstance(text_list, str):
        text_list = [text_list]
    elif isinstance(text_list, list) and text_list and isinstance(text_list[0], list):
        # Flatten one level in linear time; ``sum(lists, [])`` re-copies the accumulated list on
        # every addition and is quadratic. The emptiness check avoids an IndexError on
        # ``text_list[0]`` for an empty list.
        text_list = [text for group in text_list for text in group]

    clip = get_clip_model()
    with torch.no_grad():
        return (
            clip["model"]
            .encode_text(clip["tokenizer"](text_list).to(DEFAULT_DEVICE))
            .cpu()
            .numpy()
            .astype("float16")
        )

get_annotation

get_annotation()
Source code in molmo_spaces/utils/object_metadata.py
def get_annotation():
    """Load the gzip-compressed object annotation JSON from the assets directory."""
    metadata_path = ASSETS_DIR / "objects" / "objathor_metadata" / "objects_metadata.json.gz"
    with gzip.open(metadata_path) as compressed:
        annotations = json.load(compressed)
    return annotations

get_clip_model

get_clip_model()
Source code in molmo_spaces/utils/object_metadata.py
def get_clip_model():
    """Return the cached CLIP bundle, creating it on first use.

    The bundle is a dict with keys "model", "preprocess" and "tokenizer", stored in the
    module-level ``_CLIP`` so the model is created only once.
    """
    global _CLIP

    if _CLIP is not None:
        return _CLIP

    model, _, preprocess = open_clip.create_model_and_transforms(
        DEFAULT_CLIP_MODEL, pretrained=DEFAULT_CLIP_PRETRAIN, device=DEFAULT_DEVICE
    )
    tokenizer = open_clip.get_tokenizer(DEFAULT_CLIP_MODEL)
    _CLIP = {"model": model, "preprocess": preprocess, "tokenizer": tokenizer}
    return _CLIP

get_db

get_db()
Source code in molmo_spaces/utils/object_metadata.py
def get_db():
    """Return the object-metadata mapping, building the LMDB cache on first use.

    On the first call, the LMDB under ``get_metadata_lmdb_dir()`` is created if missing. Each
    entry is an annotation dict extended with "clip_img_features" and "clip_text_features".
    The LMDB is then opened into the module-level ``_DB``. Every call returns a fresh
    ``DictUnion`` of ``_DB`` followed by the metadata of each registered user asset library;
    because DictUnion returns the first match, LMDB entries take precedence.

    Returns:
        The DictUnion described above, or None when no objathor metadata version is registered
        (``get_metadata_lmdb_dir()`` returned None).
    """
    global _DB

    lmb_dir = get_metadata_lmdb_dir()
    if lmb_dir is None:
        return None

    if _DB is None:
        if not PickleLMDBMap.database_exists(lmb_dir):
            lmb_dir.mkdir(parents=True, exist_ok=True)
            # Re-check while holding the lock so concurrent workers do not both build the
            # database (assumes FileLock is a cross-process file lock — confirm).
            with FileLock(str(lmb_dir / ".creation_lock")):
                if not PickleLMDBMap.database_exists(lmb_dir):
                    # Ensure metadata are installed
                    get_resource_manager()

                    clip = _get_clip_features()
                    annotation = get_annotation()

                    assert len(clip["uids"]) == len(annotation), (
                        f"# clip entries {len(clip['uids'])} != # annotation entries {len(annotation)}"
                    )

                    # First, combine annotation and features under a single dict
                    # (the annotation entries themselves are extended in place)
                    combined = {}
                    for it, uid in tqdm(enumerate(clip["uids"]), "Aggregating metadata for LMDB"):
                        combined[uid] = annotation[uid]
                        combined[uid]["clip_img_features"] = clip["img_features"][it].copy()
                        combined[uid]["clip_text_features"] = clip["text_features"][it].copy()

                    # Drop the large source containers before writing the LMDB
                    clip.clear()
                    annotation.clear()

                    PickleLMDBMap.from_dict(combined, lmb_dir)

        _DB = PickleLMDBMap(lmb_dir)

    # NOTE(review): the union (and each user-library metadata object) is rebuilt on every
    # call; whether _get_user_library_metadata caches is not visible here.
    dicts = [_DB]
    for user_library_dir in USER_ASSET_LIBRARIES.values():
        dicts.append(_get_user_library_metadata(user_library_dir))
    return DictUnion(*dicts)

get_metadata_lmdb_dir

get_metadata_lmdb_dir()
Source code in molmo_spaces/utils/object_metadata.py
def get_metadata_lmdb_dir():
    """Return the LMDB directory for the registered objathor metadata version, or None.

    None is returned when no "objathor_metadata" version is registered under "objects".
    """
    versions_by_source = DATA_TYPE_TO_SOURCE_TO_VERSION.get("objects", {})
    version = versions_by_source.get("objathor_metadata", None)
    if version is None:
        return None
    return ASSETS_DIR / ".lmdb" / "objathor_metadata" / version

object_retriever

Classes:

Name Description
ObjectRetriever

Attributes:

Name Type Description
anno
r

anno module-attribute

anno = annotation(uid)

r module-attribute

ObjectRetriever

ObjectRetriever(sim_thres: float = 0.5, max_results: int = 50)

Methods:

Name Description
get_keys_values
query

Attributes:

Name Type Description
max_results
storage_path
thres
Source code in molmo_spaces/utils/object_retriever.py
def __init__(self, sim_thres: float = 0.5, max_results: int = 50):
    self.thres = sim_thres
    self.max_results = max_results
    self.tk, self.ik, self.v = self.get_keys_values()
max_results instance-attribute
max_results = max_results
storage_path class-attribute instance-attribute
storage_path = ASSETS_DIR / '.lmdb' / 'object_retriever'
thres instance-attribute
thres = sim_thres
get_keys_values
get_keys_values()
Source code in molmo_spaces/utils/object_retriever.py
def get_keys_values(self):
    """Return (text keys, image keys, uids) for retrieval, building the LMDB cache if needed.

    On first use, each annotation's CLIP text and image features are L2-normalized (image
    features per view), cast to float16, and stored under ``self.storage_path``. The text keys
    get a singleton view axis, so they have shape (N, 1, dim).
    """
    # Take 7xx MiB of disk space. It's okay
    if not PickleLMDBMap.database_exists(self.storage_path):
        uids, txt_feats, img_feats = [], [], []
        for uid, entry in ObjectMeta.annotation().items():
            uids.append(uid)

            text_vec = entry["clip_text_features"]
            txt_feats.append((text_vec / np.linalg.norm(text_vec)).astype("float16"))

            view_vecs = entry["clip_img_features"]
            unit_views = view_vecs / np.linalg.norm(view_vecs, axis=-1, keepdims=True)
            img_feats.append(unit_views.astype("float16"))

        PickleLMDBMap.from_dict(
            dict(
                txt_keys=np.array(txt_feats)[:, None, :],
                img_keys=np.array(img_feats),
                values=np.array(uids),
            ),
            self.storage_path,
        )

        # Release the per-entry lists before opening the stored arrays.
        del uids, txt_feats, img_feats

    lmdb_map = PickleLMDBMap(self.storage_path)

    # keep them all in memory
    return lmdb_map["txt_keys"], lmdb_map["img_keys"], lmdb_map["values"]
query
query(text)
Source code in molmo_spaces/utils/object_retriever.py
def query(self, text):
    """Retrieve asset uids whose CLIP features best match ``text``.

    Each asset's score is its best-view image similarity plus half its text-feature
    similarity. Assets scoring below ``self.thres`` are dropped, and at most
    ``self.max_results`` are kept.

    NOTE(review): the query embedding is normalized as a whole and the scores are flattened,
    so this appears to assume a single query string.

    Returns:
        (uids, sims): matching uids and their scores, ordered by descending score.
    """
    query_vec = compute_text_clip(text)
    query_vec = query_vec / np.linalg.norm(query_vec)

    img_score = clip_sim(self.ik, query_vec, normalize=False)
    txt_score = clip_sim(self.tk, query_vec, normalize=False, num_views=1)
    scores = (img_score + 0.5 * txt_score).flatten()

    keep = scores >= self.thres
    kept_scores = scores[keep]
    order = np.argsort(kept_scores)[::-1][: self.max_results]

    return self.v[keep][order], kept_scores[order]

patch_renderer_flags

Import this module to configure the renderer flags for the current platform.

Functions:

Name Description
patch_renderer_flags

patch_renderer_flags

patch_renderer_flags()
Source code in molmo_spaces/utils/patch_renderer_flags.py
def patch_renderer_flags():
    """On Linux, default MuJoCo and PyOpenGL to the EGL backend on EGL device 0.

    Environment variables that are already set are left untouched; other platforms are
    unaffected.
    """
    if not sys.platform.startswith("linux"):
        return
    for var, value in (
        ("MUJOCO_GL", "egl"),
        ("PYOPENGL_PLATFORM", "egl"),
        ("MUJOCO_EGL_DEVICE_ID", "0"),
    ):
        os.environ.setdefault(var, value)

pose

Functions:

Name Description
compute_lookat_forward_up

Compute forward and up unit vectors for a camera looking at a target.

pos_quat_to_pose_mat
pose_mat_to_7d

Convert 4x4 pose matrix to 7D vector (x, y, z, qw, qx, qy, qz).

pose_mat_to_pos_quat

compute_lookat_forward_up

compute_lookat_forward_up(camera_pos: ndarray, lookat_target: ndarray, camera_up: ndarray | None = None) -> tuple[ndarray, ndarray]

Compute forward and up unit vectors for a camera looking at a target.

Parameters:

Name Type Description Default
camera_pos ndarray

Camera position in world frame.

required
lookat_target ndarray

Point to look at in world frame.

required
camera_up ndarray | None

Desired up direction. Defaults to world Z-up [0, 0, 1].

None

Returns:

Type Description
tuple[ndarray, ndarray]

(forward, up) unit vectors in world frame.

Source code in molmo_spaces/utils/pose.py
def compute_lookat_forward_up(
    camera_pos: np.ndarray,
    lookat_target: np.ndarray,
    camera_up: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute forward and up unit vectors for a camera looking at a target.

    Args:
        camera_pos: Camera position in world frame.
        lookat_target: Point to look at in world frame.
        camera_up: Desired up direction. Defaults to world Z-up [0, 0, 1].

    Returns:
        (forward, up) unit vectors in world frame.
    """
    forward = np.asarray(lookat_target, dtype=float) - np.asarray(camera_pos, dtype=float)
    forward = forward / np.linalg.norm(forward)

    if camera_up is None:
        camera_up = np.array([0.0, 0.0, 1.0])

    right = np.cross(forward, camera_up)
    right_norm = np.linalg.norm(right)

    if right_norm < 1e-6:
        fallback_ref = np.array([1.0, 0.0, 0.0])
        if np.abs(np.dot(forward, fallback_ref)) > 0.9:
            fallback_ref = np.array([0.0, 1.0, 0.0])
        right = np.cross(forward, fallback_ref)
        right = right / np.linalg.norm(right)
    else:
        right = right / right_norm

    up = np.cross(right, forward)
    return forward, up

pos_quat_to_pose_mat

pos_quat_to_pose_mat(pos: ndarray | list, quat: ndarray | list | None = None) -> ndarray
Source code in molmo_spaces/utils/pose.py
def pos_quat_to_pose_mat(
    pos: np.ndarray | list, quat: np.ndarray | list | None = None
) -> np.ndarray:
    if quat is None:
        assert len(pos) == 7
        quat = pos[3:7]
        pos = pos[0:3]

    assert len(pos) == 3
    assert len(quat) == 4
    pose_matrix = np.eye(4)
    pose_matrix[:3, :3] = R.from_quat(quat, scalar_first=True).as_matrix()
    pose_matrix[:3, 3] = pos
    return pose_matrix

pose_mat_to_7d

pose_mat_to_7d(pose_matrix: ndarray) -> ndarray

Convert a 4x4 pose matrix to a 7D vector (x, y, z, qw, qx, qy, qz).

Source code in molmo_spaces/utils/pose.py
def pose_mat_to_7d(pose_matrix: np.ndarray) -> np.ndarray:
    """Convert 4x4 pose matrix to 7D vector (x, y, z ,qw, qx, qy, qz)."""
    assert pose_matrix.shape == (4, 4)
    pos = pose_matrix[:3, 3]
    rot_quat = R.from_matrix(pose_matrix[:3, :3]).as_quat(scalar_first=True)  # Returns [w, x, y, z]
    return np.concatenate([pos, rot_quat])

pose_mat_to_pos_quat

pose_mat_to_pos_quat(pose: ndarray) -> tuple[ndarray, ndarray]
Source code in molmo_spaces/utils/pose.py
def pose_mat_to_pos_quat(pose: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    pos = pose[:3, 3]
    quat = R.from_matrix(pose[:3, :3]).as_quat(scalar_first=True)
    return pos, quat

profiler_utils

Classes:

Name Description
DatagenProfiler

Per-worker profiler for distributed data generation that accumulates timing stats

MutableFloat
Profiler

Functions:

Name Description
Timer

DatagenProfiler

DatagenProfiler(logger=None, enabled: bool = True)

Per-worker profiler for distributed data generation that accumulates timing stats across episodes and houses, logging summaries to the worker logger.

Tracks operations like: - task_sampling: Time to sample a task from the task sampler - policy_setup: Time to create/setup the policy - rollout_total: Total time for a rollout (reset + all steps) - rollout_reset: Time for task.reset() - policy_get_action: Time for policy.get_action() calls (per step, accumulated) - task_step: Time for task.step() calls (per step, accumulated) - episode_total: Total time for one episode (sampling + policy setup + rollout) - save_batch_prep: Time to prepare episode for saving - save_trajectories: Time to save trajectory data

Usage

profiler = DatagenProfiler(logger)

For each episode:

with profiler.profile("task_sampling"): task = task_sampler.sample_task(...)

After each episode:

profiler.log_episode_summary(episode_idx, house_id)

After each house:

profiler.log_house_summary(house_id)

Initialize the datagen profiler.

Parameters:

Name Type Description Default
logger

Logger instance to output summaries to. If None, the log_*_summary methods log nothing (timings are still recorded).

None
enabled bool

Whether profiling is enabled. If False, all operations are no-ops.

True

Methods:

Name Description
end

End timing an operation and record the duration.

get_episode_stats

Get current episode timing stats as a dict.

get_house_stats

Get current house timing stats as a dict.

log_episode_summary

Log a summary of timing for the current episode.

log_house_summary

Log a summary of timing for the current house (accumulated across all episodes).

log_worker_summary

Log a summary of timing for the entire worker (accumulated across all houses).

profile

Context manager for profiling a block of code.

record

Directly record a duration for an operation (useful when timing is external).

start

Start timing an operation.

Attributes:

Name Type Description
enabled
logger
Source code in molmo_spaces/utils/profiler_utils.py
def __init__(self, logger=None, enabled: bool = True) -> None:
    """
    Initialize the datagen profiler.

    Args:
        logger: Logger instance to output summaries to (only its ``info`` method
            is used). If None, the ``log_*_summary`` methods emit nothing; timings
            are still recorded and available via ``get_episode_stats`` /
            ``get_house_stats``.
        enabled: Whether profiling is enabled. If False, all operations are no-ops.
    """
    self.logger = logger
    self.enabled = enabled

    # Timing storage: key -> list of durations (in seconds)
    # Episode-level: cleared after each episode summary
    self._episode_times: dict[str, list[float]] = defaultdict(list)
    # House-level: cleared after each house summary
    self._house_times: dict[str, list[float]] = defaultdict(list)
    # Worker-level (cumulative across all houses): never cleared
    self._worker_times: dict[str, list[float]] = defaultdict(list)

    # Active timers (key -> start time from time.perf_counter())
    self._active_timers: dict[str, float] = {}

    # Episode counters
    self._episode_count_in_house = 0
    self._total_episode_count = 0
    self._house_count = 0
enabled instance-attribute
enabled = enabled
logger instance-attribute
logger = logger
end
end(key: str) -> None

End timing an operation and record the duration.

Source code in molmo_spaces/utils/profiler_utils.py
def end(self, key: str) -> None:
    """Stop the timer for ``key`` and append the elapsed seconds at every aggregation level.

    Does nothing when profiling is disabled or when no timer for ``key`` is running.
    """
    if self.enabled and key in self._active_timers:
        elapsed = time.perf_counter() - self._active_timers.pop(key)
        for bucket in (self._episode_times, self._house_times, self._worker_times):
            bucket[key].append(elapsed)
get_episode_stats
get_episode_stats() -> dict[str, dict[str, float]]

Get current episode timing stats as a dict.

Source code in molmo_spaces/utils/profiler_utils.py
def get_episode_stats(self) -> dict[str, dict[str, float]]:
    """Return timing statistics for the episode in progress.

    Maps every operation key with at least one recorded duration to its
    ``total``, ``count``, ``mean``, ``min`` and ``max``.
    """
    return {
        key: {
            "total": sum(durations),
            "count": len(durations),
            "mean": sum(durations) / len(durations),
            "min": min(durations),
            "max": max(durations),
        }
        for key, durations in self._episode_times.items()
        if durations
    }
get_house_stats
get_house_stats() -> dict[str, dict[str, float]]

Get current house timing stats as a dict.

Source code in molmo_spaces/utils/profiler_utils.py
def get_house_stats(self) -> dict[str, dict[str, float]]:
    """Return timing statistics accumulated for the current house.

    Maps every operation key with at least one recorded duration to its
    ``total``, ``count``, ``mean``, ``min`` and ``max``.
    """
    return {
        key: {
            "total": sum(durations),
            "count": len(durations),
            "mean": sum(durations) / len(durations),
            "min": min(durations),
            "max": max(durations),
        }
        for key, durations in self._house_times.items()
        if durations
    }
log_episode_summary
log_episode_summary(episode_idx: int, house_id: int, success: bool | None = None) -> None

Log a summary of timing for the current episode.

Parameters:

Name Type Description Default
episode_idx int

Index of the episode within the house

required
house_id int

ID of the house being processed

required
success bool | None

Whether the episode was successful (optional)

None
Source code in molmo_spaces/utils/profiler_utils.py
def log_episode_summary(
    self, episode_idx: int, house_id: int, success: bool | None = None
) -> None:
    """
    Log a summary of timing for the current episode.

    Args:
        episode_idx: Index of the episode within the house
        house_id: ID of the house being processed
        success: Whether the episode was successful (optional)
    """
    if not self.enabled:
        return
    if self.logger is None:
        return

    self._episode_count_in_house += 1
    self._total_episode_count += 1

    success_str = ""
    if success is not None:
        success_str = f" success={success}"

    # Calculate episode total if we have individual components
    episode_total = 0.0
    for key in ["task_sampling", "policy_setup", "rollout_total"]:
        if key in self._episode_times:
            episode_total += sum(self._episode_times[key])

    total_str = f" episode_total={episode_total:.2f}s" if episode_total > 0 else ""

    self.logger.info(
        f"[PROFILE] Episode {episode_idx} house {house_id}{success_str}{total_str}:\n"
        + self._format_stats(self._episode_times)
    )

    # Clear episode-level times
    self._episode_times.clear()
log_house_summary
log_house_summary(house_id: int, success_count: int, total_count: int) -> None

Log a summary of timing for the current house (accumulated across all episodes).

Parameters:

Name Type Description Default
house_id int

ID of the house that was processed

required
success_count int

Number of successful episodes in this house

required
total_count int

Total number of episodes attempted in this house

required
Source code in molmo_spaces/utils/profiler_utils.py
def log_house_summary(self, house_id: int, success_count: int, total_count: int) -> None:
    """
    Log a summary of timing for the current house (accumulated across all episodes)
    and reset house-level timings and the per-house episode counter.

    The reset happens even when no logger is configured, so one house's timings
    never leak into the next.

    Args:
        house_id: ID of the house that was processed
        success_count: Number of successful episodes in this house
        total_count: Total number of episodes attempted in this house
    """
    if not self.enabled:
        return

    self._house_count += 1

    if self.logger is not None:
        # Calculate some high-level stats
        house_total = sum(sum(v) for v in self._house_times.values())

        self.logger.info(
            f"[PROFILE] House {house_id} complete: {success_count}/{total_count} successful, "
            f"{self._episode_count_in_house} episodes, total_time={house_total:.2f}s\n"
            + self._format_stats(self._house_times, prefix="  House averages:\n")
        )

    # Clear house-level times and reset episode counter (always)
    self._house_times.clear()
    self._episode_count_in_house = 0
log_worker_summary
log_worker_summary() -> None

Log a summary of timing for the entire worker (accumulated across all houses). Call this when the worker is shutting down.

Source code in molmo_spaces/utils/profiler_utils.py
def log_worker_summary(self) -> None:
    """
    Log timing accumulated over the worker's whole lifetime (all houses).

    Intended to be called once when the worker shuts down. Does nothing when
    profiling is disabled or no logger is configured.
    """
    if not (self.enabled and self.logger is not None):
        return

    worker_total = sum(sum(durations) for durations in self._worker_times.values())
    header = (
        f"[PROFILE] Worker complete: {self._house_count} houses, "
        f"{self._total_episode_count} episodes, total_time={worker_total:.2f}s\n"
    )
    body = self._format_stats(self._worker_times, prefix="  Worker averages:\n")
    self.logger.info(header + body)
profile
profile(key: str)

Context manager for profiling a block of code.

Source code in molmo_spaces/utils/profiler_utils.py
@contextmanager
def profile(self, key: str):
    """Time the enclosed ``with`` block under ``key``.

    ``self.end(key)`` is called even if the block raises, so the timer is
    always stopped.
    """
    self.start(key)
    try:
        yield None
    finally:
        self.end(key)
record
record(key: str, duration: float) -> None

Directly record a duration for an operation (useful when timing is external).

Source code in molmo_spaces/utils/profiler_utils.py
def record(self, key: str, duration: float) -> None:
    """Add an externally measured ``duration`` for ``key`` at episode, house and worker level.

    No-op when profiling is disabled.
    """
    if not self.enabled:
        return
    for bucket in (self._episode_times, self._house_times, self._worker_times):
        bucket[key].append(duration)
start
start(key: str) -> None

Start timing an operation.

Source code in molmo_spaces/utils/profiler_utils.py
def start(self, key: str) -> None:
    """Begin (or restart) the timer for ``key``; no-op when profiling is disabled."""
    if self.enabled:
        self._active_timers[key] = time.perf_counter()

MutableFloat dataclass

MutableFloat(value: float | None = None)

Attributes:

Name Type Description
value float | None
value class-attribute instance-attribute
value: float | None = None

Profiler

Profiler(log_realtime: bool = False, save_path: str = None)

Methods:

Name Description
end
get_avg_time
get_n
print_all
profile
save_summary
start

Attributes:

Name Type Description
log_realtime
save_path
start_timestamp
Source code in molmo_spaces/utils/profiler_utils.py
def __init__(self, log_realtime: bool = False, save_path: str = None) -> None:
    self._start_time = {}
    self._end_time = {}
    self._avg_time = {}
    self._n = {}
    self.start_timestamp = time.strftime("%Y%m%d_%H%M%S")
    self.log_realtime = log_realtime
    self.save_path = None
    if self.log_realtime:
        assert save_path is not None, "save_path must be provided if log_realtime is True"
        os.makedirs(save_path, exist_ok=True)
        self.save_path = Path(save_path) / f"profiling_summary_{self.start_timestamp}.txt"
log_realtime instance-attribute
log_realtime = log_realtime
save_path instance-attribute
save_path = None
start_timestamp instance-attribute
start_timestamp = strftime('%Y%m%d_%H%M%S')
end
end(key) -> None
Source code in molmo_spaces/utils/profiler_utils.py
def end(self, key) -> None:
    """Stop timing ``key`` and fold the elapsed time into its running average.

    Requires a prior ``start(key)`` (otherwise ``KeyError``). When
    ``log_realtime`` is set, ``_write_summary()`` is called after the update.
    """
    finished_at = time.perf_counter()
    self._end_time[key] = finished_at
    elapsed = finished_at - self._start_time[key]
    prev_count = self.get_n(key)
    self._avg_time[key] = (self.get_avg_time(key) * prev_count + elapsed) / (prev_count + 1)
    self._n[key] = prev_count + 1

    if self.log_realtime:
        self._write_summary()
get_avg_time
get_avg_time(key)
Source code in molmo_spaces/utils/profiler_utils.py
def get_avg_time(self, key):
    """Return the running mean duration recorded for ``key``, or 0 if it was never timed."""
    averages = self._avg_time
    return averages[key] if key in averages else 0
get_n
get_n(key)
Source code in molmo_spaces/utils/profiler_utils.py
def get_n(self, key):
    """Return how many times ``key`` has been timed (0 if never)."""
    counts = self._n
    return counts[key] if key in counts else 0
print_all
print_all() -> None
Source code in molmo_spaces/utils/profiler_utils.py
def print_all(self) -> None:
    """Print one ``<key>: <average>`` line per timed key, in insertion order."""
    for key, average in self._avg_time.items():
        print(f"{key}: {average}")
profile
profile(key)
Source code in molmo_spaces/utils/profiler_utils.py
@contextmanager
def profile(self, key):
    """Time the enclosed ``with`` block under ``key``.

    ``self.end(key)`` runs even when the block raises.
    """
    self.start(key)
    try:
        yield None
    finally:
        self.end(key)
save_summary
save_summary(save_path: str) -> None
Source code in molmo_spaces/utils/profiler_utils.py
def save_summary(self, save_path: str) -> None:
    """Point ``self.save_path`` at ``<save_path>/profiling_summary_<start_timestamp>.txt``
    and write the summary there via ``_write_summary()``.

    ``save_path`` is created if it does not exist.
    """
    assert save_path is not None, (
        "save_path must be provided if profiler summary is being saved"
    )
    target_dir = Path(save_path)
    os.makedirs(target_dir, exist_ok=True)
    self.save_path = target_dir / f"profiling_summary_{self.start_timestamp}.txt"
    self._write_summary()
start
start(key) -> None
Source code in molmo_spaces/utils/profiler_utils.py
def start(self, key) -> None:
    """Record ``time.perf_counter()`` as the start time of ``key`` (overwrites any previous start)."""
    now = time.perf_counter()
    self._start_time[key] = now

Timer

Timer()
Source code in molmo_spaces/utils/profiler_utils.py
@contextmanager
def Timer():
    """Measure the duration of a ``with`` block.

    Yields a ``MutableFloat`` whose ``value`` is ``None`` while the block runs.
    On exit it is set to the elapsed ``time.perf_counter()`` seconds. This
    includes exits by exception; the exception still propagates.
    """
    time_taken = MutableFloat()
    start = time.perf_counter()
    try:
        yield time_taken
    finally:
        # Previously skipped when the body raised, leaving value=None.
        time_taken.value = time.perf_counter() - start

rendering_utils

Functions:

Name Description
get_geom_seg_mask

Get a mask of all geoms descended from a body in a segmentation mask.

get_geom_seg_mask

get_geom_seg_mask(model: MjModel, seg: ndarray, body_id: int) -> ndarray

Get a mask of all geoms descended from a body in a segmentation mask.

Parameters:

Name Type Description Default
model MjModel

The model to use.

required
seg ndarray

The (H, W, 2) segmentation mask, as returned by the renderer.

required
body_id int

The id of the body to get the mask for.

required

Returns:

Type Description
ndarray

np.ndarray: A (H, W) mask of the geoms descended from the body.

Source code in molmo_spaces/utils/rendering_utils.py
def get_geom_seg_mask(model: MjModel, seg: np.ndarray, body_id: int) -> np.ndarray:
    """
    Select the pixels of a segmentation image that belong to geoms under a body.

    Args:
        model (MjModel): The model to use.
        seg (np.ndarray): The (H, W, 2) segmentation mask, as returned by the renderer.
            Channel 0 is compared against geom ids, channel 1 against the
            ``mjOBJ_GEOM`` object type.
        body_id (int): The id of the body to get the mask for.

    Returns:
        np.ndarray: A (H, W) boolean mask of the geoms descended from the body.
    """
    object_ids = seg[..., 0]
    object_types = seg[..., 1]
    body_geom_ids = descendant_geoms(model, body_id)
    return np.logical_and(
        object_types == mujoco.mjtObj.mjOBJ_GEOM.value,
        np.isin(object_ids, body_geom_ids),
    )

sampler_utils

Classes:

Name Description
UniformRandomMapSampler

Functions:

Name Description
furthest_point_sampling

Furthest Point Sampling (FPS)

Attributes:

Name Type Description
log

log module-attribute

log = getLogger(__name__)

UniformRandomMapSampler

UniformRandomMapSampler(thormap: ProcTHORMap, seed: int = 0, debug: bool = False)

Methods:

Name Description
sample

Samples N points from free space on the map.

Attributes:

Name Type Description
debug
rng
thormap
Source code in molmo_spaces/utils/sampler_utils.py
def __init__(self, thormap: ProcTHORMap, seed: int = 0, debug: bool = False) -> None:
    """Set up a sampler over ``thormap``.

    Args:
        thormap: Map whose ``occupancy_map`` is rendered into a BGR image here.
        seed: Seed for the ``numpy`` random generator stored as ``self.rng``.
        debug: If True, ``sample`` calls the debug visualization hook.
    """
    self.thormap = thormap
    self.debug = debug
    self.rng = np.random.default_rng(seed)

    # 3-channel uint8 rendering of the occupancy map (values scaled by 255;
    # presumably occupancy_map is 0/1 or boolean - TODO confirm) plus an
    # all-zero mask of the same shape.
    occupancy_u8 = self.thormap.occupancy_map.astype(np.uint8) * 255
    self._base_occupancy = cv2.cvtColor(occupancy_u8, cv2.COLOR_GRAY2BGR)
    self._base_mask = np.zeros_like(self._base_occupancy)

    # Caches; not populated in the constructor.
    self._free_points = None
    self._cached_constraint_masks = {}  # Cache masks by (pos, distance) tuple

    # Degrees -> radians conversion factor used for view angles.
    self._deg_to_rad = np.pi / 180.0
debug instance-attribute
debug = debug
rng instance-attribute
rng = default_rng(seed)
thormap instance-attribute
thormap = thormap
sample
sample(N=1, positions: ndarray | None = None, quaternions: ndarray | None = None, constraint_positions: ndarray | None = None, constraint_distances: ndarray | None = None, z_pos: float | None = None, look_at: bool | None = True, camera_pose_rel_base: ndarray | None = None, view_range_deg: float | None = 30.0)

Samples N points from free space on the map. If positions and quaternions are not provided, sample uniformly from all free points. If constraint_positions and constraint_distances are provided, use them to sample points within the specified distances from the given positions.

Parameters:

Name Type Description Default
N int

The number of points to sample. Defaults to 1.

1
positions Optional[ndarray]

An array of positions to force.

None
quaternions Optional[ndarray]

An array of quaternions to force.

None
constraint_positions Optional[ndarray]

An array of positions to treat as the center of circular constraints.

None
constraint_distances Optional[ndarray]

An array of distances to treat as the radius of circular constraints.

None
z_pos Optional[float]

If specified, overrides the z-coordinate of the sampled points. Defaults to None.

None

Returns:

Type Description

dict | None: A dict with keys "position" and "quaternion" holding the sampled (or forced) positions and orientations, or None if constrained position sampling failed.

Source code in molmo_spaces/utils/sampler_utils.py
def sample(
    self,
    N=1,
    positions: np.ndarray | None = None,
    quaternions: np.ndarray | None = None,
    constraint_positions: np.ndarray | None = None,
    constraint_distances: np.ndarray | None = None,
    z_pos: float | None = None,
    look_at: bool | None = True,
    camera_pose_rel_base: np.ndarray | None = None,
    view_range_deg: float | None = 30.0,
):
    """
    Samples N points from free space on the map.
    If positions and quaternions are not provided, sample uniformly from all free points.
    If constraint_positions and constraint_distances are provided, use them to sample points within the specified distances from the given positions.

    Args:
        N (int, optional): The number of points to sample. Defaults to 1.
        positions (Optional[np.ndarray], optional): An array of positions to force.
        quaternions (Optional[np.ndarray], optional): An array of quaternions to force.
        constraint_positions (Optional[np.ndarray], optional): An array of positions to treat as the center of circular constraints.
        constraint_distances (Optional[np.ndarray], optional): An array of distances to treat as the radius of circular constraints.
        z_pos (Optional[float], optional): If specified, overrides the z-coordinate of the sampled points. Defaults to None.

    Returns:
        np.ndarray: An array of sampled points, of shape (N, 3) if N > 1, or (3,) if N == 1.
    """
    # TODO: assert len N for positions, quaternions, constraint_positions, constraint_distances
    # TODO: positiion and constraint_positions cannot be both specified. throw error if both are provided - was this intended?
    new_positions = np.zeros((N, 3))
    new_quaternions = np.zeros((N, 4))

    # Debug visualization
    if self.debug:
        self._debug_visualize(N, constraint_positions)

    # Sample positions
    if positions is None:
        if constraint_positions is not None:
            new_positions = self._sample_constrained_positions(
                N, constraint_positions, constraint_distances
            )
            if new_positions is None:
                return None
        else:
            new_positions = self._sample_free_positions(N)
    else:
        new_positions = positions

    # Sample orientations
    if quaternions is None:
        new_quaternions = self._sample_orientations(
            N,
            new_positions,
            constraint_positions,
            look_at,
            camera_pose_rel_base,
            view_range_deg,
        )
    else:
        new_quaternions = quaternions

    # Prepare return dict
    ret = {"position": new_positions, "quaternion": new_quaternions}
    if z_pos is not None:
        ret["position"][..., 2] = z_pos

    return ret

furthest_point_sampling

furthest_point_sampling(points, k)

Furthest Point Sampling (FPS)

Parameters:

Name Type Description Default
points ndarray

Array of shape (N, D), N points in D dimensions

required
k int

Number of points to sample

required

Returns:

Name Type Description
sampled_indices ndarray

Indices of sampled points in the original array

Source code in molmo_spaces/utils/sampler_utils.py
def furthest_point_sampling(points, k, rng=None):
    """
    Furthest Point Sampling (FPS).

    Greedily selects ``k`` indices. The first is chosen at random. Each later
    index is the point whose distance to the already-selected set is largest.

    Args:
        points (np.ndarray): Array of shape (N, D), N points in D dimensions
            (anything ``np.asarray`` turns into such an array is accepted).
        k (int): Number of points to sample. ``k <= 0`` returns an empty array.
            If ``k`` exceeds the number of distinct points, indices repeat.
        rng (np.random.Generator, optional): Generator used to pick the first point.
            Defaults to the legacy global ``np.random`` state (previous behavior).

    Returns:
        sampled_indices (np.ndarray): Indices of sampled points in the original array

    Raises:
        ValueError: If ``points`` is empty while ``k > 0``.
    """
    points = np.asarray(points)
    num_points, _ = points.shape
    if k <= 0:
        return np.zeros(0, dtype=int)
    if num_points == 0:
        raise ValueError("furthest_point_sampling requires at least one point")

    sampled_indices = np.zeros(k, dtype=int)
    distances = np.full(num_points, np.inf)  # distance to the closest sampled point

    # Initialize with a random point
    if rng is None:
        sampled_indices[0] = np.random.randint(0, num_points)
    else:
        sampled_indices[0] = rng.integers(0, num_points)

    for i in range(1, k):
        # Only the newest sample can shrink the running minimum distances.
        last_point = points[sampled_indices[i - 1]]
        dist_to_last = np.linalg.norm(points - last_point, axis=1)
        distances = np.minimum(distances, dist_to_last)

        # Choose the point with the maximum distance to the sampled set
        sampled_indices[i] = np.argmax(distances)

    return sampled_indices

save_utils

Functions:

Name Description
batch_observations

Transpose a batch of observation dicts to a dict of batched

byte_array_to_string
convert_to_arr
dict_to_byte_array
is_camera_sensor

Determine if a sensor corresponds to a camera (RGB or depth) that produces image data.

prepare_episode_for_saving

Transform raw episode history into batched format ready for save_trajectories().

safe_to_tensor

Safely convert data to tensor, handling different dimensionalities.

save_frames_to_mp4

Save RGB frames to MP4 video file.

save_trajectories

Save trajectories in the expected hierarchical HDF5 format.

save_videos_from_raw_observations

Save videos immediately from raw observations before batch processing.

Attributes:

Name Type Description
COMPR
log

COMPR module-attribute

COMPR = 'lzf'

log module-attribute

log = getLogger(__name__)

batch_observations

batch_observations(observations: list[dict], sensor_suite: SensorSuite, device: device | None = None) -> dict[str, dict | Tensor]

Transpose a batch of observation dicts to a dict of batched observations.

Arguments

observations : List of dicts of observations. device : The torch.device to put the resulting tensors on. Will not move the tensors if None.

Returns

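Transposed dict mirroring the observation structure, where each leaf is a tensor stacking that sensor's values along a new leading batch dimension.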

Source code in molmo_spaces/utils/save_utils.py
def batch_observations(
    observations: list[dict], sensor_suite: SensorSuite, device: torch.device | None = None
) -> dict[str, dict | torch.Tensor]:
    """Transpose a batch of observation dicts to a dict of batched
    observations.

    # Arguments

    observations :  List of dicts of observations. Dict-valued sensors (per
        ``sensor_suite``) are first serialized with ``convert_to_arr``, which
        modifies these dicts in place.
    sensor_suite : Suite used to look up which sensors are dict-valued.
    device : The torch.device to put the resulting tensors on.
        Will not move the tensors if None.

    # Returns

    Dict mirroring the (possibly nested) structure of the first observation,
    where each leaf is a tensor stacking that sensor's values along a new
    leading batch dimension. An empty ``observations`` list is returned as-is.
    """

    def collect_arrays(observation: dict[str, Any]) -> dict[str, dict | list]:
        """Build the batch skeleton from one observation: nested dicts are recursed
        into and every leaf becomes a one-element list holding its raw value."""
        if not isinstance(observation, dict):
            raise TypeError(f"Expected dict observation, got {type(observation)}: {observation}")

        batch_dict: defaultdict = defaultdict(list)

        for sensor in observation:
            if isinstance(observation[sensor], dict):
                # For nested dicts, recurse
                batch_dict[sensor] = collect_arrays(observation[sensor])
            else:
                # For leaf values, just add the raw data (don't convert to tensor yet)
                batch_dict[sensor].append(observation[sensor])

        return dict(batch_dict)

    def fill_arrays(input_batch: Any, observation: dict[str, Any]) -> None:
        """Append each leaf of ``observation`` to the matching list in ``input_batch``.

        A key absent from the first observation raises KeyError.
        """
        for sensor in observation:
            if isinstance(observation[sensor], dict):
                fill_arrays(input_batch[sensor], observation[sensor])
            else:
                input_batch[sensor].append(observation[sensor])

    def to_device(tensor: torch.Tensor) -> torch.Tensor:
        """Move ``tensor`` to ``device`` unless no device was requested."""
        return tensor if device is None else tensor.to(device=device)

    def stack_and_tensorize(input_batch: Any) -> None:
        """Replace every leaf list in ``input_batch`` with one stacked tensor, in place.

        Fast path: stack with NumPy and convert once via ``torch.from_numpy``.
        If that fails, fall back to per-element ``safe_to_tensor`` + ``torch.stack``.
        """
        for sensor in input_batch:
            if isinstance(input_batch[sensor], dict):
                stack_and_tensorize(input_batch[sensor])
                continue

            data_list = input_batch[sensor]
            try:
                # Stack as numpy first (much faster than stacking tensors);
                # non-ndarray values (scalars, lists) are converted element-wise.
                stacked_numpy = np.stack(
                    [d if isinstance(d, np.ndarray) else np.array(d) for d in data_list],
                    axis=0,
                )
                input_batch[sensor] = to_device(torch.from_numpy(stacked_numpy))
            except Exception as e:
                # Best-effort fallback, e.g. for dtypes torch.from_numpy cannot wrap.
                log.warning(
                    f"Fast stacking failed for {sensor}, falling back to tensor stack: {e}"
                )
                input_batch[sensor] = torch.stack(
                    [to_device(safe_to_tensor(batch)) for batch in data_list],
                    dim=0,
                )

    if len(observations) == 0:
        # NOTE(review): returns the empty input list itself rather than an empty dict.
        return cast(dict[str, dict | torch.Tensor], observations)

    observations = convert_to_arr(observations, sensor_suite)
    batch = collect_arrays(observations[0])

    for obs in observations[1:]:
        fill_arrays(batch, obs)

    stack_and_tensorize(batch)

    return cast(dict[str, dict | torch.Tensor], batch)

byte_array_to_string

byte_array_to_string(bytes_to_decode: ndarray)
Source code in molmo_spaces/utils/save_utils.py
def byte_array_to_string(bytes_to_decode: np.ndarray):
    """Decode a NUL-padded UTF-8 byte buffer back into a ``str``.

    Trailing NUL bytes (zero padding) are removed before decoding.
    """
    raw = bytes(bytes_to_decode)
    trimmed = raw.rstrip(b"\x00")
    return trimmed.decode("utf-8")

convert_to_arr

convert_to_arr(observations: list[dict], sensor_suite: SensorSuite) -> list[dict]
Source code in molmo_spaces/utils/save_utils.py
def convert_to_arr(observations: list[dict], sensor_suite: SensorSuite) -> list[dict]:
    """Serialize dict-valued sensor readings into fixed-length byte arrays, in place.

    For every observation, each sensor whose ``sensor_suite`` entry has ``is_dict``
    set is replaced by ``dict_to_byte_array(value, name, str_max_len)``. Every
    sensor name must be present in ``sensor_suite.sensors``. Returns the same
    (mutated) list.
    """
    for observation in observations:
        for sensor_name in observation:
            spec = sensor_suite.sensors[sensor_name]
            if not spec.is_dict:
                continue
            observation[sensor_name] = dict_to_byte_array(
                observation[sensor_name], sensor_name, spec.str_max_len
            )
    return observations

dict_to_byte_array

dict_to_byte_array(target_dict, sensor_name: str, str_max_len: int)
Source code in molmo_spaces/utils/save_utils.py
def dict_to_byte_array(target_dict, sensor_name: str, str_max_len: int):
    """Serialize ``target_dict`` as compact, key-sorted JSON into a zero-padded uint8 array.

    The result always has length ``str_max_len``. JSON longer than that is
    truncated (with a warning) and is then no longer valid JSON.
    """
    payload = json.dumps(target_dict, sort_keys=True, separators=(",", ":"))
    if len(payload) > str_max_len:
        log.warning(
            f"Warning: Truncated JSON string to {str_max_len} characters for {sensor_name}. Data values may be missing."
        )
    encoded = payload.encode("utf-8")[:str_max_len]
    buffer = np.zeros(str_max_len, dtype=np.uint8)
    buffer[: len(encoded)] = np.frombuffer(encoded, dtype=np.uint8)
    return buffer

is_camera_sensor

is_camera_sensor(sensor_name: str, sensor_suite: SensorSuite | None = None) -> bool

Determine if a sensor corresponds to a camera (RGB or depth) that produces image data.

Uses sensor type metadata from sensor_suite. No naming-heuristic fallback is implemented: if sensor_suite is not provided, or does not contain the sensor, the sensor is not reported as a camera.

Parameters:

Name Type Description Default
sensor_name str

Name of the sensor to check

required
sensor_suite SensorSuite | None

Optional SensorSuite to query for sensor type metadata

None

Returns:

Type Description
bool

True if the sensor is a camera that produces image data (RGB or depth), False otherwise.

bool

Returns False for camera parameter sensors (CameraParameterSensor) which contain

bool

metadata but not image data.

Source code in molmo_spaces/utils/save_utils.py
def is_camera_sensor(sensor_name: str, sensor_suite: SensorSuite | None = None) -> bool:
    """
    Determine if a sensor corresponds to a camera (RGB or depth) that produces image data.

    The decision is made from the sensor's type in ``sensor_suite``: the sensor must be a
    ``CameraSensor`` or ``DepthSensor`` and must not be a ``CameraParameterSensor``
    (parameter sensors carry metadata, not image data).

    Args:
        sensor_name: Name of the sensor to check
        sensor_suite: Optional SensorSuite to query for sensor type metadata

    Returns:
        True if the sensor is a camera that produces image data (RGB or depth), False otherwise.
        Also returns False when ``sensor_suite`` is None or does not contain ``sensor_name``.
    """
    # Import here to avoid circular dependency
    from molmo_spaces.env.sensors_cameras import CameraParameterSensor, CameraSensor, DepthSensor

    if sensor_suite is None or sensor_name not in sensor_suite.sensors:
        # NOTE(review): the previous docstring promised a name-based fallback for this case,
        # but none was implemented; the function fell off the end and returned None.
        # Return an explicit False (same truthiness) so the declared `-> bool` holds.
        return False

    sensor = sensor_suite.sensors[sensor_name]
    # CameraParameterSensor may subclass a camera sensor type, so exclude it explicitly.
    return isinstance(sensor, CameraSensor | DepthSensor) and not isinstance(
        sensor, CameraParameterSensor
    )

prepare_episode_for_saving

prepare_episode_for_saving(history: dict, sensor_suite: SensorSuite, fps: float, save_dir: str | None = None, episode_idx: int = 0, save_file_suffix: str = '', remove_sensors_if_save_dir: bool = True) -> dict[str, Tensor] | None

Transform raw episode history into batched format ready for save_trajectories().

Takes the output of task.get_history() and produces a single dict with all data batched along the time dimension.

Parameters:

Name Type Description Default
history dict

Dict from task.get_history() containing: - "observations": List[List[Dict]] - [timestep][batch_idx][sensor_name] - "rewards": List[List[float]] - "terminals": List[List[bool]] - "truncateds": List[List[bool]] - "actions": List[...] (optional, currently unused) - "obs_scene": Dict (optional)

required
sensor_suite SensorSuite

SensorSuite for observation processing

required
save_dir str | None

Optional directory to save videos immediately (before batching)

None
episode_idx int

Episode index for video filenames

0
save_file_suffix str

Optional suffix for video filenames

''
remove_sensors_if_save_dir bool

remove camera-related sensors if video saved

True

Returns:

Name Type Description
dict[str, Tensor] | None

Dict[str, Tensor] with all data batched along the time dimension, or None if there is no data.

Structure:

{
    # Sensor observations (from batch_observations)
    "qpos": Tensor(T, ...), "tcp_pose": Tensor(T, 7), ...,
    # Episode metadata (camera data removed if videos saved)
    "rewards": Tensor(T,), "terminals": Tensor(T,), "truncateds": Tensor(T,), "successes": Tensor(T,), "obs_scene": str (JSON),
}

Source code in molmo_spaces/utils/save_utils.py
def prepare_episode_for_saving(
    history: dict,
    sensor_suite: SensorSuite,
    fps: float,
    save_dir: str | None = None,
    episode_idx: int = 0,
    save_file_suffix: str = "",
    remove_sensors_if_save_dir: bool = True,
) -> dict[str, torch.Tensor] | None:
    """
    Transform raw episode history into batched format ready for save_trajectories().

    Takes the output of task.get_history() and produces a single dict with all data
    batched along the time dimension. ``history`` is consumed: its "observations",
    "rewards", "terminals", "truncateds" and "successes" entries are popped
    ("obs_scene" is left in place).

    Args:
        history: Dict from task.get_history() containing:
            - "observations": List[List[Dict]] - [timestep][batch_idx][sensor_name]
            - "rewards": List[List[float]]
            - "terminals": List[List[bool]]
            - "truncateds": List[List[bool]]
            - "successes": List[List[bool]] (optional)
            - "actions": List[...] (optional, currently unused)
            - "obs_scene": Dict (optional)
        sensor_suite: SensorSuite for observation processing
        fps: Frames per second, used when saving videos
        save_dir: Optional directory to save videos immediately (before batching)
        episode_idx: Episode index for video filenames
        save_file_suffix: Optional suffix for video filenames
        remove_sensors_if_save_dir: remove camera-related sensors if video saved

    Returns:
        Dict[str, Tensor] with all data batched along time dimension, or None if no data
        Structure:
        {
            # Sensor observations (from batch_observations)
            "qpos": Tensor(T, ...),
            "tcp_pose": Tensor(T, 7),
            ...
            # Episode metadata (camera data removed if videos saved)
            "rewards": Tensor(T,),
            "terminals": Tensor(T,),
            "truncateds": Tensor(T,),
            "successes": Tensor(T,),
            "obs_scene": str (JSON),
        }
    """
    import gc

    observations_list = history.get("observations", [])
    if not observations_list:
        log.info("No observation history to save")
        return None

    # Flatten batch dimension (extract first environment since batch_size=1).
    flattened_obs = [timestep_obs[0] for timestep_obs in observations_list]
    log.info(f"Preparing episode data: {len(flattened_obs)} timesteps")

    # Drop the nested per-timestep lists (the history's reference and ours). The
    # per-step observation dicts themselves stay alive through ``flattened_obs``.
    del observations_list
    history.pop("observations", None)
    gc.collect()

    # MEMORY OPTIMIZATION: save videos BEFORE batching. Camera images dominate episode
    # memory; writing them out now and stripping them lets batching skip them entirely.
    if save_dir is not None:
        log.debug(f"Saving videos before batching for episode {episode_idx}")
        save_videos_from_raw_observations(
            flattened_obs,
            save_dir,
            fps,
            episode_idx=episode_idx,
            save_file_suffix=save_file_suffix,
            sensor_suite=sensor_suite,
        )
        if remove_sensors_if_save_dir:
            _drop_camera_sensors(flattened_obs, sensor_suite)
        gc.collect()

    # Batch observations: List[Dict] -> Dict[str, Tensor(T, ...)]
    batched_data = batch_observations(flattened_obs, sensor_suite)
    del flattened_obs
    gc.collect()

    # Per-step episode signals, stored as [timestep][batch_idx]; keep env 0 only.
    # Insertion order into batched_data is rewards, terminals, truncateds, successes.
    for key, dtype in (
        ("rewards", np.float32),
        ("terminals", bool),
        ("truncateds", bool),
        ("successes", bool),
    ):
        if key in history:
            batched_data[key] = _stack_first_env_column(history[key], dtype)
            history.pop(key, None)

    if "obs_scene" in history:
        batched_data["obs_scene"] = json.dumps(history["obs_scene"])

    # Final GC to clean up any remaining references
    gc.collect()

    return batched_data


def _drop_camera_sensors(flattened_obs, sensor_suite) -> None:
    """Remove RGB/depth camera entries (but not ``*_seg`` sensors) from each observation dict in place."""
    removed_sensors = set()
    for obs in flattened_obs:
        # Collect first, then pop: the dict must not change size while being iterated.
        to_remove = [
            sensor_name
            for sensor_name in obs
            if is_camera_sensor(sensor_name, sensor_suite) and not sensor_name.endswith("_seg")
        ]
        for sensor_name in to_remove:
            obs.pop(sensor_name, None)
            removed_sensors.add(sensor_name)

    if removed_sensors:
        log.debug(
            f"Removed camera sensors from observations before batching: {removed_sensors}"
        )


def _stack_first_env_column(per_step_values, dtype) -> torch.Tensor:
    """Take element 0 of every timestep entry and return them as a 1-D tensor of ``dtype``."""
    # torch.from_numpy shares memory with the temporary array; no extra copy is made.
    return torch.from_numpy(np.array([step[0] for step in per_step_values], dtype=dtype))

safe_to_tensor

safe_to_tensor(data)

Safely convert data to tensor, handling different dimensionalities.

Source code in molmo_spaces/utils/save_utils.py
def safe_to_tensor(data):
    """Convert ``data`` to a ``torch.Tensor`` without reordering axes or changing dtype.

    NumPy arrays of any rank are copied and wrapped unchanged, so images stay HWC and
    keep their dtype (unlike torchvision's ``to_tensor``, which converts HWC->CHW and
    uint8->float32). Any other input is routed through ``np.array``. If that conversion
    raises ``ValueError``, it is retried with ``dtype=np.float32``.
    """
    if not isinstance(data, np.ndarray):
        try:
            return torch.from_numpy(np.array(data))
        except ValueError:
            # Last resort: force a float32 array.
            return torch.from_numpy(np.array(data, dtype=np.float32))

    # Copy so the tensor owns its memory and later edits to ``data`` don't leak through.
    return torch.from_numpy(data.copy())

save_frames_to_mp4

save_frames_to_mp4(frames: Sequence[ndarray], file_path: str, fps: float, extra_kwargs: dict[str, Any] | None = None) -> None

Save RGB frames to MP4 video file.

Low-level function that assumes frames are already validated and in uint8 format. Use _save_sensor_video() for high-level saving with validation.

Source code in molmo_spaces/utils/save_utils.py
def save_frames_to_mp4(
    frames: Sequence[np.ndarray],
    file_path: str,
    fps: float,
    extra_kwargs: dict[str, Any] | None = None,
) -> None:
    """Save RGB frames to MP4 video file.

    Low-level function that assumes frames are already validated and in uint8 format.
    Use _save_sensor_video() for high-level saving with validation.

    Args:
        frames: Sequence of frames or a pre-stacked array. Non-uint8 input is converted
            with a warning. Floating-point frames are clipped to [0, 1] and scaled to
            [0, 255]; other dtypes are cast with ``astype(np.uint8)``.
        file_path: Output path; the suffix is forced to ``.mp4``. Parent directories
            are created if missing.
        fps: Frames per second of the output video.
        extra_kwargs: Extra keyword arguments for the imageio writer. They override
            the defaults (``fps``, ``quality=5``) on both the primary and fallback path.
    """
    os.makedirs(os.path.dirname(os.path.abspath(file_path)), exist_ok=True)

    if not isinstance(frames, np.ndarray):
        frames = np.array(frames)

    # Frames should already be uint8 if coming from _save_sensor_video()
    if frames.dtype != np.uint8:
        log.warning(f"save_frames_to_mp4: Expected uint8, got {frames.dtype}. Converting...")
        if np.issubdtype(frames.dtype, np.floating):
            # Any float width (incl. float16) is treated as normalized [0, 1] intensity.
            frames = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            frames = frames.astype(np.uint8)

    # Ensure file path has .mp4 extension
    file_path = Path(file_path)
    if file_path.suffix != ".mp4":
        file_path = file_path.with_suffix(".mp4")

    # One merged kwargs dict for both paths. Passing fps/quality explicitly alongside
    # **extra_kwargs would raise TypeError if the caller also supplied them.
    writer_kwargs = {
        "fps": fps,
        "quality": 5,
        **(extra_kwargs or {}),
    }

    # Explicitly use ffmpeg plugin to avoid imageio selecting wrong plugin (e.g., tifffile)
    try:
        # The context manager closes the writer even if append_data fails midway.
        with imageio.get_writer(file_path, format="ffmpeg", **writer_kwargs) as writer:
            for frame in frames:
                writer.append_data(frame)
    except (ImportError, OSError, ValueError, RuntimeError) as e:
        # Fallback to mimwrite if explicit writer fails
        # Common exceptions:
        # - ImportError: ffmpeg plugin not installed
        # - OSError: file system or ffmpeg execution issues
        # - ValueError: invalid parameters
        # - RuntimeError: ffmpeg runtime errors
        log.debug(f"FFmpeg writer failed ({type(e).__name__}: {e}), falling back to mimwrite")
        # Don't pass macro_block_size parameter as it's not supported in newer imageio versions
        imageio.mimwrite(file_path, frames, format="mp4", **writer_kwargs)

save_trajectories

save_trajectories(episodes_data: list[dict[str, Tensor]], save_dir: str, fps: float, save_file_suffix: str = '', save_mp4s: bool = True, logger: Logger | None = None) -> Path

Save trajectories in the expected hierarchical HDF5 format.

Parameters:

Name Type Description Default
episodes_data list[dict[str, Tensor]]

List of batched observations (output of batch_observations()) Each episode is a Dict[str, torch.Tensor] where tensors have shape (T, ...)

required
save_dir str

Directory to save files

required
fps float

Frames per second of episode data

required
save_file_suffix str

Optional suffix for filenames

''
save_mp4s bool

Whether to save MP4 videos

True
logger Logger | None

Optional logger to use (defaults to module logger)

None

Expected structure: traj_N/ ├── obs/ │ ├── agent/ │ │ ├── qpos (T,str_max_len) │ │ └── qvel (T,str_max_len) │ ├── extra/ │ │ ├── obj_start (T,7) │ │ ├── obj_end (T,7) │ │ ├── tcp_pose (T,7) │ │ ├── grasp_pose (T,7) │ │ ├── robot_base_pose (T,7) │ │ └── door_state (T,str_max_len) │ │ ├── joint_angle │ │ ├── opening_percentage │ │ ├── handle_position │ │ ├── handle_extents │ │ ├── door_position │ │ └── is_open │ ├── sensor_param/ │ │ └── render_camera/ │ │ ├── extrinsic_cv (T,3,4) │ │ ├── cam2world_gl (T,4,4) │ │ └── intrinsic_cv (T,3,3) │ └── sensor_data/ │ └── render_camera/ │ ├── rgb (T,str_max_len) - video path │ ├── depth (T,str_max_len) - video path │ └── segmentation (T,str_max_len) - video path ├── actions (T,str_max_len) - flattened ├── extra/ - original formats for reference └── episode metadata...

Source code in molmo_spaces/utils/save_utils.py
def save_trajectories(
    episodes_data: list[dict[str, torch.Tensor]],
    save_dir: str,
    fps: float,
    save_file_suffix: str = "",
    save_mp4s: bool = True,
    logger: logging.Logger | None = None,
) -> Path:
    """
    Save trajectories in the expected hierarchical HDF5 format.

    Args:
        episodes_data: List of batched observations (output of batch_observations())
                      Each episode is a Dict[str, torch.Tensor] where tensors have shape (T, ...)
        save_dir: Directory to save files
        fps: Frames per second of episode data
        save_file_suffix: Optional suffix for filenames
        save_mp4s: Whether to save MP4 videos
        logger: Optional logger to use (defaults to module logger)

    Returns:
        Path to the written ``trajectories{save_file_suffix}.h5`` file.

    Raises:
        RuntimeError: if ``save_mp4s`` is True and camera sensors are still detected in
            the first episode (videos must be saved before batching).

    Episode-level datasets: ``terminated``, ``truncated``, ``rewards``, ``obs_scene``,
    ``success`` and ``fail``. When an input is missing, a placeholder is written and a
    warning is logged. Terminal flags are read from ``"terminateds"`` or, if absent,
    ``"terminals"`` (the key produced by prepare_episode_for_saving()). ``T`` is taken
    from ``len(episode_data["qpos"])``.

    Expected structure:
    traj_N/
    ├── obs/
    │   ├── agent/
    │   │   ├── qpos (T,str_max_len)
    │   │   └── qvel (T,str_max_len)
    │   ├── extra/
    │   │   ├── obj_start (T,7)
    │   │   ├── obj_end (T,7)
    │   │   ├── tcp_pose (T,7)
    │   │   ├── grasp_pose (T,7)
    │   │   ├── robot_base_pose (T,7)
    │   │   └── door_state (T,str_max_len)
    │   │       ├── joint_angle
    │   │       ├── opening_percentage
    │   │       ├── handle_position
    │   │       ├── handle_extents
    │   │       ├── door_position
    │   │       └── is_open
    │   ├── sensor_param/
    │   │   └── render_camera/
    │   │       ├── extrinsic_cv (T,3,4)
    │   │       ├── cam2world_gl (T,4,4)
    │   │       └── intrinsic_cv (T,3,3)
    │   └── sensor_data/
    │       └── render_camera/
    │           ├── rgb (T,str_max_len) - video path
    │           ├── depth (T,str_max_len) - video path
    │           └── segmentation (T,str_max_len) - video path
    ├── actions (T,str_max_len) - flattened
    ├── extra/ - original formats for reference
    └── episode metadata...
    """
    logger = logger or log

    os.makedirs(save_dir, exist_ok=True)

    # Save HDF5 file
    hdf5_path = os.path.join(save_dir, f"trajectories{save_file_suffix}.h5")

    with h5py.File(hdf5_path, "w") as hdf5_file:
        for episode_idx, episode_data in enumerate(episodes_data):
            episode_group = hdf5_file.create_group(f"traj_{episode_idx}")

            logger.debug(
                f"\n[SAVE_UTILS DEBUG] Processing episode {episode_idx} with batched observations"
            )
            logger.debug(f"[SAVE_UTILS DEBUG] Available sensors: {list(episode_data.keys())}")

            # Create main obs group with expected structure
            obs_group = episode_group.create_group("obs")

            _save_agent_data_from_batched(obs_group, episode_data)
            _save_extra_data_from_batched(obs_group, episode_data)
            _save_sensor_params_from_batched(obs_group, episode_data)
            _save_sensor_data_from_batched(
                obs_group, episode_data, episode_idx, save_dir, save_file_suffix, logger
            )
            _save_actions_from_batched(episode_group, episode_data, logger)
            _save_env_states_from_batched(episode_group, episode_data)

            num_timesteps = len(episode_data["qpos"])
            # Placeholder arrays only mark the last step when there is one; indexing
            # [-1] on an empty array would raise IndexError.
            has_steps = num_timesteps > 0

            # prepare_episode_for_saving() stores terminal flags under "terminals";
            # "terminateds" is still accepted for other producers.
            terminated_key = next(
                (key for key in ("terminateds", "terminals") if key in episode_data), None
            )
            if terminated_key is not None:
                terminated_array = episode_data[terminated_key]
            else:
                logger.warning("No terminated recorded, assuming episode ended")
                terminated_array = np.zeros(num_timesteps, dtype=bool)
                if has_steps:
                    terminated_array[-1] = True  # Assume episode ended
            episode_group.create_dataset("terminated", data=terminated_array, compression=COMPR)

            if "truncateds" in episode_data:
                truncated_array = episode_data["truncateds"]
            else:
                logger.warning("No truncateds recorded, assuming untruncated")
                truncated_array = np.zeros(num_timesteps, dtype=bool)
            episode_group.create_dataset("truncated", data=truncated_array, compression=COMPR)

            if "rewards" in episode_data:
                rewards_array = episode_data["rewards"]
            else:
                logger.warning("No rewards recorded, assuming success")
                rewards_array = np.zeros(num_timesteps, dtype=np.float32)
                if has_steps:
                    rewards_array[-1] = 1.0  # Reward at end
            episode_group.create_dataset("rewards", data=rewards_array, compression=COMPR)

            if "obs_scene" in episode_data:
                obs_scene = episode_data["obs_scene"]
            else:
                logger.warning("No obs_scene recorded, using default")
                obs_scene = json.dumps({})
            # Scalar (string) dataset: compression is not applicable.
            episode_group.create_dataset("obs_scene", data=obs_scene)

            if "successes" in episode_data:
                success_array = episode_data["successes"]
            else:
                logger.warning("No successes recorded, assuming success at end")
                success_array = np.zeros(num_timesteps, dtype=bool)
                if has_steps:
                    success_array[-1] = True  # Assume success

            fail_array = ~success_array
            episode_group.create_dataset("success", data=success_array, compression=COMPR)
            episode_group.create_dataset("fail", data=fail_array, compression=COMPR)

    logger.info(f"Saved {len(episodes_data)} episodes to: {os.path.abspath(save_dir)}")

    # Videos should have been saved during prepare_episode_for_saving() before batching
    # This is required for memory optimization - camera data is removed before batching
    if save_mp4s:
        if len(episodes_data) > 0:
            first_episode = episodes_data[0]
            # NOTE(review): is_camera_sensor() is called without a SensorSuite here; as
            # visible in this file it then reports no cameras, so this guard may never fire.
            camera_sensors_in_batch = [s for s in first_episode if is_camera_sensor(s)]

            if camera_sensors_in_batch:
                raise RuntimeError(
                    f"Camera data still present in batched episodes: {camera_sensors_in_batch}. "
                    f"Videos must be saved via save_videos_from_raw_observations() before batching. "
                    f"Pass save_dir to prepare_episode_for_saving() to enable this."
                )

        logger.debug("Videos were saved during prepare_episode_for_saving() (before batching)")
    return Path(hdf5_path)

save_videos_from_raw_observations

save_videos_from_raw_observations(observations_list, save_dir, fps, episode_idx=0, save_file_suffix='', sensor_suite: SensorSuite | None = None) -> None

Save videos immediately from raw observations before batch processing. This avoids the corruption that happens during batch_observations tensor conversion.

Parameters:

Name Type Description Default
observations_list

List of raw observation dicts from episode steps

required
save_dir

Directory to save videos

required
fps

Frames per second of episode data

required
episode_idx

Episode index for naming

0
save_file_suffix

Optional suffix for filenames

''
sensor_suite SensorSuite | None

Optional SensorSuite for proper sensor type detection

None
Source code in molmo_spaces/utils/save_utils.py
def save_videos_from_raw_observations(
    observations_list,
    save_dir,
    fps,
    episode_idx=0,
    save_file_suffix="",
    sensor_suite: SensorSuite | None = None,
) -> None:
    """
    Write one MP4 per camera sensor straight from the raw per-step observation dicts.

    Runs before batch processing to sidestep the corruption seen during
    batch_observations tensor conversion. Camera sensors are discovered from the first
    observation via is_camera_sensor(). Names ending in ``_depth`` become depth videos,
    names ending in ``_seg`` are skipped, and every other camera is treated as RGB.
    Only ``np.ndarray`` frames are used; a leading batch axis of size 1 on 4-D frames
    is squeezed away.

    Args:
        observations_list: List of raw observation dicts from episode steps
        save_dir: Directory to save videos (created if missing)
        fps: Frames per second of episode data
        episode_idx: Episode index for naming
        save_file_suffix: Optional suffix for filenames
        sensor_suite: Optional SensorSuite for proper sensor type detection
    """
    os.makedirs(save_dir, exist_ok=True)

    if not observations_list:
        log.warning("No observations to save videos from")
        return

    def _video_kind(name):
        # Depth is tested before the RGB default so e.g. "wrist_camera_depth" is depth.
        if name.endswith("_depth"):
            return "depth"
        if name.endswith("_seg"):
            return None  # segmentation streams are not encoded as video
        return "rgb"

    def _unbatch(frame):
        return frame[0] if frame.ndim == 4 and frame.shape[0] == 1 else frame

    camera_sensors = {}
    for name in observations_list[0]:
        if is_camera_sensor(name, sensor_suite):
            kind = _video_kind(name)
            if kind is not None:
                camera_sensors[name] = kind

    kinds = list(camera_sensors.values())
    log.debug(
        f"INFO: Saving videos for episode {episode_idx} with {len(camera_sensors)} cameras "
        f"({kinds.count('rgb')} RGB, "
        f"{kinds.count('depth')} depth) "
        f"and {len(observations_list)} frames"
    )
    log.debug(f"Camera sensors detected: {camera_sensors}")

    for name, kind in camera_sensors.items():
        frames = [
            _unbatch(obs[name])
            for obs in observations_list
            if name in obs and isinstance(obs[name], np.ndarray)
        ]

        if not frames:
            log.warning(f"WARNING: No frames found for camera {name}")
            continue

        video_path = os.path.join(
            save_dir, f"episode_{episode_idx:08d}_{name}{save_file_suffix}.mp4"
        )
        # Delegates validation and dtype conversion to the shared video writer.
        _save_sensor_video(
            sensor_name=name,
            sensor_type=kind,
            frames=frames,
            video_path=video_path,
            fps=fps,
            logger=log,
        )
        log.debug(f"SUCCESS: Saved {kind} video: {video_path} ({len(frames)} frames)")

scene_maps

Classes:

Name Description
ProcTHORMap
THORMap

Map of the Mujoco scene.

iTHORMap

Functions:

Name Description
circular_kernel
sample_around_point

Sample a 2D point around a given point within a given radius.

Attributes:

Name Type Description
dir_path
free_points
free_points_px
ithormap
log
one_room_map
procthormap
procthormap_loaded
room_map
run_ithor_map_generation
run_procthor_map_generation
xmls

dir_path module-attribute

dir_path = f'{ASSETS_DIR}/scenes/procthor-10k-train'

free_points module-attribute

free_points = get_free_points_by_room('room|2')

free_points_px module-attribute

free_points_px = pos_m_to_px(free_points[0])

ithormap module-attribute

ithormap = from_mj_model_path(model_path, agent_radius=0.25, px_per_m=200, device_id=None)

log module-attribute

log = getLogger(__name__)

one_room_map module-attribute

one_room_map = room_map

procthormap module-attribute

procthormap = from_mj_model_path(model_path, agent_radius=None, px_per_m=200, device_id=None)

procthormap_loaded module-attribute

procthormap_loaded = load(replace('.xml', '_map.png'))

room_map module-attribute

room_map = _room_map

run_ithor_map_generation module-attribute

run_ithor_map_generation = True

run_procthor_map_generation module-attribute

run_procthor_map_generation = True

xmls module-attribute

xmls = glob(join(dir_path, 'train_1.xml'))

ProcTHORMap

ProcTHORMap(occupancy: ndarray, world_to_map: ndarray, map_to_world: ndarray, px_per_m: int, room_map: ndarray = None, room_ids_to_name: dict = None, use_filament: bool = False)

Bases: THORMap

Methods:

Name Description
__call__
check_collision
from_mj_model_path

Generate a ProcTHORMap from a MuJoCo model with the open door path cleared.

get_free_points
get_free_points_by_room
load
pos_m_to_px
pos_px_to_m
safe_model_data
save
save_map

Attributes:

Name Type Description
MAP_TYPES
map_to_world
occupancy
occupancy_map
occupancy_scale_factor
occupancy_world_dims
px_per_m
room_ids_to_name
room_map
room_names_to_id
voxel_map
voxel_scale_to_world
world_to_map
Source code in molmo_spaces/utils/scene_maps.py
def __init__(
    self,
    occupancy: np.ndarray,
    world_to_map: np.ndarray,
    map_to_world: np.ndarray,
    px_per_m: int,
    room_map: np.ndarray | None = None,
    room_ids_to_name: dict | None = None,
    use_filament: bool = False,
):
    """Store the occupancy grid, world/map transforms and optional room labelling.

    Args:
        occupancy: Occupancy grid. It is forwarded to the base class as
            ``occupancy_map`` and also kept on ``self.occupancy``.
        world_to_map: Transform from world to map coordinates (stored as-is).
        map_to_world: Transform from map to world coordinates (stored as-is).
        px_per_m: Map resolution, forwarded to the base class.
        room_map: Optional room map array, stored on ``self._room_map``.
        room_ids_to_name: Optional mapping from room id to room name. When given, its
            inverse is stored on ``self.room_names_to_id``; otherwise that attribute is None.
        use_filament: Forwarded to the base class.
    """
    super().__init__(occupancy_map=occupancy, px_per_m=px_per_m, use_filament=use_filament)
    self.occupancy = occupancy
    self._room_map = room_map
    self.room_ids_to_name = room_ids_to_name
    # Inverse lookup (name -> id); if two ids share a name, the later id wins.
    self.room_names_to_id = (
        {name: room_id for room_id, name in room_ids_to_name.items()}
        if room_ids_to_name is not None
        else None
    )
    self.world_to_map = world_to_map
    self.map_to_world = map_to_world
MAP_TYPES class-attribute instance-attribute
MAP_TYPES = ['occupancy', 'voxel']
map_to_world instance-attribute
map_to_world = map_to_world
occupancy instance-attribute
occupancy = occupancy
occupancy_map property
occupancy_map
occupancy_scale_factor property
occupancy_scale_factor
occupancy_world_dims property
occupancy_world_dims
px_per_m property
px_per_m
room_ids_to_name instance-attribute
room_ids_to_name = room_ids_to_name
room_map property
room_map
room_names_to_id instance-attribute
room_names_to_id = {v: k for k, v in room_ids_to_name.items()}  (None when room_ids_to_name is None)
voxel_map property
voxel_map
voxel_scale_to_world property
voxel_scale_to_world
world_to_map instance-attribute
world_to_map = world_to_map
__call__
__call__(r, c, map_type: str = 'occupancy')
Source code in molmo_spaces/utils/scene_maps.py
def __call__(self, r, c, map_type: str = "occupancy"):
    """Convert a map pixel (row ``r``, column ``c``) into a 3-D position.

    For ``map_type="occupancy"``, x comes from the column and y from the row, both
    measured from the centre of ``self.occupancy_map`` and scaled by
    ``self._occupancy_scale_factor``; z is 0.

    Raises:
        NotImplementedError: for ``map_type="voxel"``.
        ValueError: for any other unknown ``map_type`` (previously returned None silently).
    """
    if map_type == "voxel":
        raise NotImplementedError("Voxel map is not implemented yet")
    if map_type != "occupancy":
        raise ValueError(f"Unknown map_type {map_type!r}; expected 'occupancy' or 'voxel'")

    position = np.zeros(3)
    h, w = self.occupancy_map.shape
    scale = self._occupancy_scale_factor
    position[0] = scale * (c - w / 2)
    position[1] = scale * (r - h / 2)
    return position
check_collision
check_collision(pos: ndarray) -> bool | ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def check_collision(self, pos: np.ndarray) -> bool | np.ndarray:
    """Look up occupancy for metric positions.

    Positions are converted to pixel indices with ``self.pos_m_to_px``. Rows that land
    inside ``self.occupancy`` get that pixel's occupancy value (cast to bool); rows
    outside the map are reported as ``False``.
    NOTE(review): assumes ``pos_m_to_px`` yields an (N, 2) integer (row, col) array and
    that ``@single_or_batch`` (defined elsewhere) adapts single positions — confirm.
    """
    pixels = self.pos_m_to_px(pos)
    inside = np.all((pixels >= 0) & (pixels < self.occupancy.shape), axis=1)
    rows = pixels[inside, 0]
    cols = pixels[inside, 1]
    # The boolean mask doubles as the result buffer, so out-of-map rows stay False.
    # (Marking them as colliding was considered upstream but is disabled.)
    inside[inside] = self.occupancy[rows, cols]
    return inside
from_mj_model_path classmethod
from_mj_model_path(model_path: str, camera: str | None = None, agent_radius: float | None = None, px_per_m: int = 100, data: MjData | None = None, device_id: int = None, use_filament: bool = False)

Generate a ProcTHORMap from a MuJoCo model with the open door path cleared.

This method renders occupancy maps at three camera heights:
  • 5.0 m: Base map with full wall geometry.
  • 2.5 m and 1.5 m: Lower views that capture the door opening, since walls might not be visible at these heights.

It computes a door mask as the area that is occupied at 2.5 m but free at 1.5 m and applies that mask to the 5.0 m map. The method also computes the transformation matrices for mapping between world and map coordinates.

Returns:

Name Type Description
ProcTHORMap

An instance with the occupancy map having the door path cleared.

Source code in molmo_spaces/utils/scene_maps.py
@classmethod
def from_mj_model_path(
    cls,
    model_path: str,
    camera: str | None = None,
    agent_radius: float | None = None,
    px_per_m: int = 100,
    data: MjData | None = None,
    device_id: int = None,
    use_filament: bool = False,
):
    """
    Generate a ProcTHORMap from a MuJoCo model with the open door path cleared.

    The scene is loaded from ``model_path`` and every geom whose name contains
    "ceiling" is removed. It is then rendered once as an orthographic,
    top-down segmentation image from a camera 5.0 m above the center of the
    floor bounding box, which is padded by 1 m on each side in x and y. From
    that image:
      - pixels showing a floor geom (name starting with "room|" or "room_")
        are free; all other pixels are occupied;
      - the footprint of door assemblies treated as open is dilated and
        marked free so that open doorways connect rooms. Open door
        assemblies are those with a hinged door whose ``qpos0`` is non-zero,
        or jointless assemblies made of exactly two bodies. The door leaves
        themselves are excluded from the footprint;
      - each floor pixel is labelled ``floor geom id + 1`` in the room map,
        with 0 meaning background.

    Args:
        model_path: Path to the MJCF scene file.
        camera: Unused by this implementation.
        agent_radius: Optional agent radius in meters. When given, obstacles
            are dilated by ``int(agent_radius * effective_px_per_m)`` pixels.
            Room-map cells that become occupied are set to 0.
        px_per_m: Requested resolution. The image height is rounded to whole
            pixels, so the resolution stored on the map may differ slightly.
        data: Optional MjData forwarded to ``safe_model_data``. When None, a
            fresh MjData is created and ``mj_forward`` is run.
            NOTE(review): a caller-supplied MjData must match the model
            compiled after ceiling geoms are deleted; confirm callers ensure this.
        device_id: Forwarded to the renderer factory.
        use_filament: Forwarded to the renderer factory.

    Returns:
      ProcTHORMap: An instance whose ``occupancy`` is True for free cells,
      with the door path cleared.
    """
    spec = mujoco.MjSpec.from_file(model_path)

    # Recursively collect all ceiling geoms so the top-down view sees the floor.
    ceiling_geoms = []

    def collect_ceiling_geoms_recursively(body_spec: mujoco.MjsBody) -> None:
        """Recursively traverse all bodies and collect ceiling geoms."""
        for geom in body_spec.geoms:
            geom_name = geom.name
            if geom_name and "ceiling" in geom_name.lower():
                ceiling_geoms.append(geom)
        for child_body in body_spec.bodies:
            collect_ceiling_geoms_recursively(child_body)

    collect_ceiling_geoms_recursively(spec.worldbody)

    for geom in ceiling_geoms:
        log.debug(f"[ProcTHORMap] Deleting ceiling geom: {geom.name}")
        spec.delete(geom)  # for mujoco>3.3.5

    # Compile; safe_model_data creates MjData and runs mj_forward when data is None.
    model, data = cls.safe_model_data(spec, data)

    # Identify floor geoms; each one labels a room.
    floor_ids = []
    room_ids_to_name = {}
    for geom_id in range(model.ngeom):
        geom_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_GEOM, geom_id)
        if geom_name and geom_name.startswith(("room|", "room_")):
            floor_ids.append(geom_id)
            room_body_id = model.geom(geom_id).bodyid.item()
            room_body_name = model.body(room_body_id).name
            assert room_body_name and room_body_name.startswith(("world", "room_")), (
                "Room body name must start with 'world' or 'room_'"
            )
            room_ids_to_name[geom_id + 1] = room_body_name  # 0 is background

    assert len(floor_ids) > 0, "No floors found in the model"

    # Group every body under its top-level door/doorway/doorframe root body.
    parent_to_child = {}
    for body_id in range(model.nbody):
        root_body = model.body(model.body(body_id).rootid.item())
        root_body_name = root_body.name
        if root_body_name and root_body_name.startswith(("door_", "doorway_", "doorframe_")):
            parent_to_child.setdefault(root_body.id, []).append(body_id)

    ### NOTE MAP REQUIRE DOOR HAS JOINTS
    # Sets: these are only used for membership tests in the per-geom loop below.
    door_ids = set()  # open hinged door bodies (door leaves)
    doorway_ids = set()  # bodies whose assemblies are treated as passable
    for children in parent_to_child.values():
        for door_id in children:
            jntadr = model.body(door_id).jntadr.item()
            if (
                jntadr >= 0
                and model.joint(jntadr).type == mujoco.mjtJoint.mjJNT_HINGE
                and model.joint(jntadr).qpos0.item() != 0.0
            ):
                door_ids.add(door_id)
                doorway_ids.update(children)
            # Doors without joints are always open or closed: closed ones have
            # 3 bodies (frame, door and handle), open ones 2 (itself and frame).
            if jntadr < 0 and len(children) == 2:
                doorway_ids.add(door_id)

    doorframe_geom_ids = []
    door_geom_ids = []
    for geom_id in range(model.ngeom):
        body_id = model.geom(geom_id).bodyid.item()
        body = model.body(body_id)
        if body_id in door_ids or body.parentid.item() in door_ids:
            door_geom_ids.append(geom_id)
        if body.rootid.item() in doorway_ids:
            doorframe_geom_ids.append(geom_id)

    # Axis-aligned bounding box of the floors, padded by 1 m per side in x and y.
    aabb_center, aabb_size = geom_aabb(model, data, floor_ids, tight_mesh=False)
    aabb_size += np.array([2, 2, 0])

    def render_occupancy(cam_distance: float):
        """Render a top-down occupancy map (True = occupied) at ``cam_distance``.

        Returns ``(occ, room_map, effective_px, (h, w))``. When
        ``cam_distance == 5.0`` the camera-to-world transform is appended.
        """
        cam = mujoco.MjvCamera()
        cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        cam.lookat[:] = aabb_center
        cam.distance = cam_distance
        cam.azimuth = 0
        cam.elevation = -90
        cam.orthographic = 1

        h = round(px_per_m * aabb_size[0])
        w = round(px_per_m * aabb_size[1])
        effective_px = h / aabb_size[0]  # accounts for rounding of h

        renderer = _get_renderer(
            model, width=w, height=h, device_id=device_id, use_filament=use_filament
        )
        # Ensure the renderer is released even if rendering or the assert fails.
        try:
            renderer.update(data, cam)
            for gl_cam in renderer.scene.camera:
                gl_cam.orthographic = 1
                gl_cam.frustum_bottom = -aabb_size[0] / 2
                gl_cam.frustum_top = aabb_size[0] / 2

            renderer.enable_segmentation_rendering()
            seg_geom = renderer.render()[..., 0]  # channel 0 holds geom ids
            cam_to_world = None
            if cam_distance == 5.0:
                scene_cam = renderer.scene.camera[0]
                cam_to_world = np.eye(4)
                cam_to_world[:3, 3] = scene_cam.pos
                camera_x_ax = np.cross(scene_cam.up, -scene_cam.forward)
                cam_to_world[:3, :3] = np.column_stack(
                    (camera_x_ax, scene_cam.up, -scene_cam.forward)
                )
                assert np.allclose(cam_to_world[:3, 2], [0, 0, 1]), (
                    "Camera must be pointing straight down"
                )
        finally:
            renderer.close()

        # Room-id image: floor pixels carry (floor geom id + 1); 0 is background.
        occ_room_floor = np.zeros_like(seg_geom, dtype=int)
        for fid in floor_ids:
            occ_room_floor[seg_geom == fid] = fid + 1

        # Occupancy (True = occupied): every pixel that is not a floor geom.
        occ_floor = np.ones_like(seg_geom, dtype=bool)
        for fid in floor_ids:
            occ_floor &= seg_geom != fid

        # Door path = door-assembly footprint minus the door leaves, dilated
        # with circular_kernel(15).
        occ_door = np.isin(seg_geom, door_geom_ids)
        occ_door_path = np.isin(seg_geom, doorframe_geom_ids) & ~occ_door
        occ_door_path = cv2.dilate(occ_door_path.astype(np.uint8), circular_kernel(15)).astype(
            bool
        )

        # Clear the door path from the occupied map.
        occ = occ_floor
        occ[occ_door_path] = False

        if cam_distance == 5.0:
            return occ, occ_room_floor, effective_px, (h, w), cam_to_world
        return occ, occ_room_floor, effective_px, (h, w)

    ### TODO: is this treating all doors as open?
    # Might need to render at different heights and compare.
    occ_map_5, occ_room_floor_map_5, effective_px, (h, w), cam_to_world = render_occupancy(5.0)

    occ_final = occ_map_5.copy()
    occ_room_floor_final = occ_room_floor_map_5.copy()
    if agent_radius is not None:
        rad_px = int(agent_radius * effective_px)
        kernel = circular_kernel(rad_px)
        occ_final = cv2.dilate(occ_final.astype(np.uint8), kernel).astype(bool)
        occ_room_floor_final[occ_final] = 0

    # (x, y, z, 1) in camera frame -> (row, col) in the map.
    cam_to_map = np.array([[0, -effective_px, 0, h / 2], [effective_px, 0, 0, w / 2]])
    world_to_map = cam_to_map @ inverse_homogeneous_matrix(cam_to_world)

    # (row, col, 1) in the map -> (x, y, 0) in the world frame.
    map_to_centered = np.array([[0, 1, -w / 2], [-1, 0, h / 2], [0, 0, 1]])
    centered_to_cam = np.array([[1 / effective_px, 0, 0], [0, 1 / effective_px, 0], [0, 0, 1]])
    cam_to_world_floor = cam_to_world[:-1, [0, 1, 3]].copy()
    cam_to_world_floor[2, 2] = 0
    map_to_world = cam_to_world_floor @ centered_to_cam @ map_to_centered

    # Flip so that free space is True and occupied space is False.
    occ_final = ~occ_final

    instance = cls(
        occupancy=occ_final,
        room_map=occ_room_floor_final,
        room_ids_to_name=room_ids_to_name,
        world_to_map=world_to_map,
        map_to_world=map_to_world,
        px_per_m=effective_px,
    )
    # Base map for reference: not dilated and NOT flipped (True = occupied).
    instance.occupancy_base = occ_map_5

    # Drop the temporary MuJoCo objects and force collection to free them promptly.
    del model
    del data
    gc.collect()

    return instance
get_free_points
get_free_points() -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
def get_free_points(self) -> np.ndarray:
    """Return the world coordinates of every free (True) occupancy cell.

    Cells are enumerated in row-major order as ``(row, col)`` pairs and then
    converted with ``pos_px_to_m``.
    """
    rows, cols = np.nonzero(self.occupancy)
    pixel_coords = np.stack((rows, cols), axis=-1)
    return self.pos_px_to_m(pixel_coords)
get_free_points_by_room
get_free_points_by_room(room_key: str) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
def get_free_points_by_room(self, room_key: str) -> np.ndarray:
    """Return world coordinates of the free cells belonging to room ``room_key``.

    ``room_key`` is looked up in ``room_names_to_id`` (KeyError if absent).
    Free cells whose ``room_map`` label equals that id are converted with
    ``pos_px_to_m``, in row-major order.
    """
    target_id = self.room_names_to_id[room_key]
    free_rows, free_cols = np.nonzero(self.occupancy)
    in_room = self.room_map[free_rows, free_cols] == target_id
    selected = np.stack((free_rows[in_room], free_cols[in_room]), axis=-1)
    return self.pos_px_to_m(selected)
load classmethod
load(path: str, agent_radius: float | None = None)
Source code in molmo_spaces/utils/scene_maps.py
@classmethod
def load(cls, path: str, agent_radius: float | None = None):
    if path.endswith(".png"):
        # stacked images
        all_img = Image.open(path)
        # first channel is occupancy, second channel is room map
        img = np.array(all_img)[:, :, 0]
        room_map = np.array(all_img)[:, :, 2]

        world_to_map = np.array(json.loads(all_img.info["world_to_map"]))
        map_to_world = np.array(json.loads(all_img.info["map_to_world"]))
        px_per_m = int(np.ceil(json.loads(all_img.info["px_per_m"])))
        room_ids_to_name = json.loads(all_img.info["room_ids_to_name"])
        room_ids_to_name = {int(k): v for k, v in room_ids_to_name.items()}
        occupancy = np.array(img) > 0
        room_map = np.array(room_map)

        if agent_radius is not None:
            occupancy = ~occupancy
            rad_px = int(agent_radius * px_per_m)
            kernel = circular_kernel(rad_px)
            occupancy = cv2.dilate(occupancy.astype(np.uint8), kernel).astype(bool)
            room_map[occupancy] = 0
            occupancy = ~occupancy

        return cls(
            occupancy=occupancy,
            room_map=room_map,
            room_ids_to_name=room_ids_to_name,
            world_to_map=world_to_map,
            map_to_world=map_to_world,
            px_per_m=px_per_m,
        )
    elif path.endswith(".npz"):
        data = np.load(path)

        world_to_map = data["world_to_map"]
        map_to_world = data["map_to_world"]
        px_per_m = data["px_per_m"]
        room_ids_to_name = data["room_ids_to_name"]
        occupancy = data["occupancy"]
        room_map = data["occupancy"]

        if agent_radius is not None:
            occupancy = ~occupancy
            rad_px = int(agent_radius * px_per_m)
            kernel = circular_kernel(rad_px)
            occupancy = cv2.dilate(occupancy.astype(np.uint8), kernel).astype(bool)
            room_map[occupancy] = 0
            occupancy = ~occupancy

        return cls(
            occupancy=occupancy,
            room_map=room_map,
            room_ids_to_name=data["room_ids_to_name"],
            world_to_map=data["world_to_map"],
            map_to_world=data["map_to_world"],
            px_per_m=data["px_per_m"],
        )
    else:
        raise ValueError(f"Unsupported file format: {path}")
pos_m_to_px
pos_m_to_px(pos_m: ndarray) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def pos_m_to_px(self, pos_m: np.ndarray) -> np.ndarray:
    """Project world positions (meters) to integer ``(row, col)`` map pixels.

    Expects an ``(N, 3)`` array. Each point is homogenized, presumably by
    appending a 1 (helper not shown), multiplied by ``world_to_map`` and
    rounded to the nearest pixel.
    """
    assert pos_m.ndim == 2 and pos_m.shape[-1] == 3
    homogeneous = homogenize(pos_m)
    projected = homogeneous @ self.world_to_map.T
    return np.round(projected).astype(int)
pos_px_to_m
pos_px_to_m(pos_px: ndarray) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def pos_px_to_m(self, pos_px: np.ndarray) -> np.ndarray:
    """Map ``(row, col)`` pixel coordinates to world positions.

    Expects an ``(N, 2)`` array. Each point is homogenized, presumably by
    appending a 1 (helper not shown), and multiplied by ``map_to_world``.
    The result is not rounded.
    """
    assert pos_px.ndim == 2 and pos_px.shape[-1] == 2
    homogeneous = homogenize(pos_px)
    world = homogeneous @ self.map_to_world.T
    return world
safe_model_data staticmethod
safe_model_data(spec, data=None)
Source code in molmo_spaces/utils/scene_maps.py
@staticmethod
def safe_model_data(spec, data=None):
    """Compile ``spec`` into an MjModel and pair it with an MjData.

    Bodies matching blacklisted asset UIDs are deleted before compiling. If
    compilation raises ValueError, the error is passed to
    ``_handle_compile_error_and_blacklist`` and then re-raised. When ``data``
    is None, a new MjData is created and ``mj_forward`` is run on it;
    otherwise the given ``data`` is returned unchanged.

    Returns:
        ``(model, data)``
    """
    _delete_blacklisted_bodies(spec)

    try:
        model = spec.compile()
    except ValueError as compile_error:
        _handle_compile_error_and_blacklist(compile_error)
        raise
    finally:
        # Drop this frame's reference to the spec (the caller may still hold one).
        del spec

    if data is not None:
        return model, data

    data = mujoco.MjData(model)
    mujoco.mj_forward(model, data)
    return model, data
save
save(path: str)
Source code in molmo_spaces/utils/scene_maps.py
def save(self, path: str):
    """Write the map to ``path``; the extension selects the format.

    ``.png``: an RGB image whose channels 0 and 1 are the occupancy map
    (255 = free, 0 = occupied) and channel 2 is the room-id map cast to uint8.
    ``world_to_map``, ``map_to_world``, ``px_per_m`` and ``room_ids_to_name``
    are stored as JSON text chunks. Room ids above 255 cannot be represented
    in the uint8 channel and wrap around; a warning is logged in that case.

    ``.npz``: all arrays plus ``room_ids_to_name``. The dict is stored as a
    pickled object array, so it must be loaded with ``allow_pickle=True``.

    Raises:
        ValueError: If the extension is neither ``.png`` nor ``.npz``.
    """
    if path.endswith(".png"):
        room_map = np.asarray(self.room_map)
        uint8_max = np.iinfo(np.uint8).max
        if room_map.size and room_map.max() > uint8_max:
            log.warning(
                f"Room ids up to {room_map.max()} exceed {uint8_max}; they will wrap "
                f"when saved to {path}. Use .npz to preserve them."
            )

        # Stack occupancy (twice) and room map as the three image channels.
        occupancy_img = self.occupancy.astype(np.uint8) * 255
        room_map_img = room_map.astype(np.uint8)
        stacked = np.stack([occupancy_img, occupancy_img, room_map_img], axis=2)
        all_img = Image.fromarray(stacked.astype(np.uint8))

        metadata = PngInfo()
        metadata.add_text("world_to_map", json.dumps(self.world_to_map.tolist()))
        metadata.add_text("map_to_world", json.dumps(self.map_to_world.tolist()))
        # float() so numpy scalars / 0-d arrays are JSON serializable.
        metadata.add_text("px_per_m", json.dumps(float(self.px_per_m)))
        metadata.add_text("room_ids_to_name", json.dumps(self.room_ids_to_name))
        all_img.save(path, pnginfo=metadata)
    elif path.endswith(".npz"):
        np.savez(
            path,
            occupancy=self.occupancy,
            room_map=self.room_map,
            room_ids_to_name=self.room_ids_to_name,
            world_to_map=self.world_to_map,
            map_to_world=self.map_to_world,
            px_per_m=self.px_per_m,
        )
    else:
        raise ValueError(f"Unsupported file format: {path}")
save_map
save_map(path)
Source code in molmo_spaces/utils/scene_maps.py
def save_map(self, path):
    """Write the occupancy image to ``path``; voxel export is not implemented.

    Raises:
        AssertionError: If an occupancy map is present and ``path`` is not a
            ``.png``, or a voxel map is present and ``path`` is not a ``.npy``.
        NotImplementedError: Whenever a voxel map is present.
    """
    has_occupancy = self.occupancy_map is not None
    if has_occupancy:
        assert path.endswith(".png"), "Only PNG format is supported"
        cv2.imwrite(path, self.occupancy_map)

    if self.voxel_map is None:
        return

    assert path.endswith(".npy"), "Only NPY format is supported"
    raise NotImplementedError("Voxel map is not implemented yet")

THORMap

THORMap(occupancy_map=None, occupancy_scale_factor=None, occupancy_world_dims=None, voxel_map=None, voxel_scale_factor=None, px_per_m: int = 100, use_filament: bool = False)

Map of the MuJoCo scene, including fixed, hinged/articulated, and free objects but excluding the dynamic agent.

Methods:

Name Description
__call__
save_map

Attributes:

Name Type Description
MAP_TYPES
occupancy_map
occupancy_scale_factor
occupancy_world_dims
voxel_map
voxel_scale_to_world
Source code in molmo_spaces/utils/scene_maps.py
def __init__(
    self,
    occupancy_map=None,
    occupancy_scale_factor=None,
    occupancy_world_dims=None,
    voxel_map=None,
    voxel_scale_factor=None,
    px_per_m: int = 100,
    use_filament: bool = False,
):
    """Store the map representations and rendering options as private attributes.

    Args:
        occupancy_map: Occupancy image, or None.
        occupancy_scale_factor: Multiplier that converts pixel offsets from
            the map center into metric offsets (used by ``__call__``).
        occupancy_world_dims: Stored as-is. NOTE(review): meaning not visible here.
        voxel_map: Voxel representation, or None. Voxel support is not implemented.
        voxel_scale_factor: Stored as-is for the voxel representation.
        px_per_m: Pixels per meter.
        use_filament: Stored rendering-backend flag.
    """
    # Occupancy-grid representation.
    self._occupancy_map, self._occupancy_scale_factor, self._occupancy_world_dims = (
        occupancy_map,
        occupancy_scale_factor,
        occupancy_world_dims,
    )
    # Voxel representation.
    self._voxel_map, self._voxel_scale_factor = voxel_map, voxel_scale_factor
    # Resolution and renderer options.
    self._px_per_m, self._use_filament = px_per_m, use_filament
MAP_TYPES class-attribute instance-attribute
MAP_TYPES = ['occupancy', 'voxel']
occupancy_map property
occupancy_map
occupancy_scale_factor property
occupancy_scale_factor
occupancy_world_dims property
occupancy_world_dims
voxel_map property
voxel_map
voxel_scale_to_world property
voxel_scale_to_world
__call__
__call__(r, c, map_type: str = 'occupancy')
Source code in molmo_spaces/utils/scene_maps.py
def __call__(self, r, c, map_type: str = "occupancy"):
    """Convert a map cell ``(r, c)`` into a metric position.

    For the occupancy map, the column is mapped to x and the row to y, both
    measured from the center of the map and multiplied by
    ``self._occupancy_scale_factor``; z is always 0.

    Args:
        r: Row index in the occupancy map.
        c: Column index in the occupancy map.
        map_type: "occupancy" or "voxel".

    Returns:
        np.ndarray of shape (3,) holding ``(x, y, 0)``.

    Raises:
        NotImplementedError: If ``map_type`` is "voxel".
        ValueError: If ``map_type`` is not a recognized map type.
    """
    if map_type == "occupancy":
        h, w = self.occupancy_map.shape
        position = np.zeros(3)
        position[0] = self._occupancy_scale_factor * (c - w / 2)
        position[1] = self._occupancy_scale_factor * (r - h / 2)
        return position
    if map_type == "voxel":
        raise NotImplementedError("Voxel map is not implemented yet")
    # Fail loudly instead of silently returning None for a typo'd map type.
    raise ValueError(f"Unknown map_type {map_type!r}; expected 'occupancy' or 'voxel'")
save_map
save_map(path)
Source code in molmo_spaces/utils/scene_maps.py
def save_map(self, path):
    """Write the occupancy image to ``path``; voxel export is not implemented.

    Raises:
        AssertionError: If an occupancy map is present and ``path`` is not a
            ``.png``, or a voxel map is present and ``path`` is not a ``.npy``.
        NotImplementedError: Whenever a voxel map is present.
    """
    has_occupancy = self.occupancy_map is not None
    if has_occupancy:
        assert path.endswith(".png"), "Only PNG format is supported"
        cv2.imwrite(path, self.occupancy_map)

    if self.voxel_map is None:
        return

    assert path.endswith(".npy"), "Only NPY format is supported"
    raise NotImplementedError("Voxel map is not implemented yet")

iTHORMap

iTHORMap(occupancy: ndarray, world_to_map: ndarray, map_to_world: ndarray, px_per_m: int)

Bases: ProcTHORMap

Methods:

Name Description
__call__
check_collision
from_mj_model_path

Generate an iTHORMap from a MuJoCo model by rendering a top-down occupancy map.

get_free_points
get_free_points_by_room
load
pos_m_to_px
pos_px_to_m
safe_model_data
save
save_map

Attributes:

Name Type Description
MAP_TYPES
map_to_world
occupancy
occupancy_map
occupancy_scale_factor
occupancy_world_dims
px_per_m
room_ids_to_name
room_map
room_names_to_id
voxel_map
voxel_scale_to_world
world_to_map
Source code in molmo_spaces/utils/scene_maps.py
def __init__(
    self,
    occupancy: np.ndarray,
    world_to_map: np.ndarray,
    map_to_world: np.ndarray,
    px_per_m: int,
):
    """Create an iTHOR map from an occupancy grid and its coordinate transforms.

    Only these four values are forwarded to the ProcTHORMap base class; no
    room map or room names are passed. NOTE(review): this relies on the base
    initializer defaulting those; confirm.
    """
    base_kwargs = dict(
        occupancy=occupancy,
        world_to_map=world_to_map,
        map_to_world=map_to_world,
        px_per_m=px_per_m,
    )
    super().__init__(**base_kwargs)
MAP_TYPES class-attribute instance-attribute
MAP_TYPES = ['occupancy', 'voxel']
map_to_world instance-attribute
map_to_world = map_to_world
occupancy instance-attribute
occupancy = occupancy
occupancy_map property
occupancy_map
occupancy_scale_factor property
occupancy_scale_factor
occupancy_world_dims property
occupancy_world_dims
px_per_m property
px_per_m
room_ids_to_name instance-attribute
room_ids_to_name = room_ids_to_name
room_map property
room_map
room_names_to_id instance-attribute
room_names_to_id = {v: k for k, v in (items())}
voxel_map property
voxel_map
voxel_scale_to_world property
voxel_scale_to_world
world_to_map instance-attribute
world_to_map = world_to_map
__call__
__call__(r, c, map_type: str = 'occupancy')
Source code in molmo_spaces/utils/scene_maps.py
def __call__(self, r, c, map_type: str = "occupancy"):
    """Convert a map cell ``(r, c)`` into a metric position.

    For the occupancy map, the column is mapped to x and the row to y, both
    measured from the center of the map and multiplied by
    ``self._occupancy_scale_factor``; z is always 0.

    Args:
        r: Row index in the occupancy map.
        c: Column index in the occupancy map.
        map_type: "occupancy" or "voxel".

    Returns:
        np.ndarray of shape (3,) holding ``(x, y, 0)``.

    Raises:
        NotImplementedError: If ``map_type`` is "voxel".
        ValueError: If ``map_type`` is not a recognized map type.
    """
    if map_type == "occupancy":
        h, w = self.occupancy_map.shape
        position = np.zeros(3)
        position[0] = self._occupancy_scale_factor * (c - w / 2)
        position[1] = self._occupancy_scale_factor * (r - h / 2)
        return position
    if map_type == "voxel":
        raise NotImplementedError("Voxel map is not implemented yet")
    # Fail loudly instead of silently returning None for a typo'd map type.
    raise ValueError(f"Unknown map_type {map_type!r}; expected 'occupancy' or 'voxel'")
check_collision
check_collision(pos: ndarray) -> bool | ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def check_collision(self, pos: np.ndarray) -> bool | np.ndarray:
    """Look up the occupancy value at each world position.

    Positions are projected to pixels with ``pos_m_to_px``. Positions that
    fall outside the map yield False; positions inside yield the value of
    ``self.occupancy`` at that cell.

    NOTE(review): ``occupancy`` is True for free space (``get_free_points``
    treats True cells as free), so despite its name this returns True where
    the position is in-bounds AND free. Confirm callers expect that.
    """
    pixels = self.pos_m_to_px(pos)
    # A pixel is usable only if both indices are non-negative and below the map shape.
    inside = (pixels >= 0).all(axis=1) & (pixels < self.occupancy.shape).all(axis=1)
    result = inside.copy()
    rows = pixels[inside, 0]
    cols = pixels[inside, 1]
    result[inside] = self.occupancy[rows, cols]
    return result
from_mj_model_path classmethod
from_mj_model_path(model_path, camera: str | None = None, agent_radius: float | None = None, px_per_m: int = 100, data: MjData | None = None, device_id: int = None, use_filament: bool = False)

Generate an iTHORMap from a MuJoCo model by rendering a top-down occupancy map.

The scene is loaded twice. The first pass finds top-level bodies whose visual geoms all lie more than 1.5 m above the top of the floor. The second pass deletes those bodies, along with top-level bodies whose names contain "ceiling" or "light". It then renders an orthographic top-down segmentation image, either from the named camera or, when none is given, from 5.0 m above the center of the floor.

Pixels showing a visual floor geom are free; everything else is occupied. The method also computes the transformation matrices for mapping between world and map coordinates.

Returns:

Name Type Description
iTHORMap

An instance whose occupancy map is True for free cells.

Source code in molmo_spaces/utils/scene_maps.py
@classmethod
def from_mj_model_path(
    cls,
    model_path,
    camera: str | None = None,
    agent_radius: float | None = None,
    px_per_m: int = 100,
    data: MjData | None = None,
    device_id: int = None,
    use_filament: bool = False,
):
    """Build an iTHORMap by rendering a top-down segmentation of a MuJoCo scene.

    The scene is loaded twice:
      1. Compile it and classify top-level bodies by height. A body counts as
         "high" only if every one of its visual geoms (``contype == 0``) lies
         entirely more than 1.5 m above the top of the floor bounding box.
      2. Reload it, delete top-level bodies whose names contain "ceiling" or
         "light" or that were classified high, and render an orthographic,
         top-down segmentation image.

    Pixels showing a visual geom whose name contains "floor" are free; all
    other pixels are occupied.

    Args:
        model_path: Path to the MJCF scene file.
        camera: Name of an orthographic camera in the model to render from.
            When None, a free orthographic camera is placed 5.0 m above the
            center of the floor bounding box, padded by 1 m on each side.
        agent_radius: Optional agent radius in meters. When given, obstacles
            are dilated by ``int(agent_radius * px_per_m)`` pixels.
        px_per_m: Requested resolution when ``camera`` is None. It is
            recomputed from the rounded image height. When ``camera`` is given,
            it is derived from the camera's resolution and ``fovy``.
        data: Ignored. Fresh MjData is created for each pass because the
            model is recompiled.
        device_id: Forwarded to the renderer factory.
        use_filament: Forwarded to the renderer factory.

    Returns:
        iTHORMap whose ``occupancy`` is True for free cells.
    """

    def find_visual_floor_ids(mj_model) -> list[int]:
        """Return ids of visual-only (contype == 0) geoms whose name contains "floor"."""
        ids = []
        for geom_id in range(mj_model.ngeom):
            geom_name = mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_GEOM, geom_id)
            if geom_name and "floor" in geom_name.lower():
                if mj_model.geom(geom_id).contype == 0:  # is "__VISUAL_MJT__"
                    ids.append(geom_id)
        assert len(ids) > 0, "No floors found in the model"
        return ids

    # Pass 1: classify top-level bodies by the height of their visual geoms.
    spec = mujoco.MjSpec.from_file(model_path)
    model, data = cls.safe_model_data(spec)

    floor_ids = find_visual_floor_ids(model)

    # Height threshold: 1.5 m above the top face of the floor bounding box.
    floor_center, floor_size = geom_aabb(model, data, floor_ids, tight_mesh=False)
    z_threshold = 1.5 + (floor_center[2] + floor_size[2] / 2)

    high_names = set()
    low_names = set()
    for geom_id in range(model.ngeom):
        if model.geom(geom_id).contype != 0:
            continue  # only visual geoms take part in the classification
        body_id = model.geom_bodyid[geom_id]
        body_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_BODY, body_id)
        if not body_name:
            # Unnamed bodies cannot be matched by name in pass 2 (and would
            # previously crash on .split).
            continue

        # Map to the top-level body name by zeroing the second-to-last
        # "_"-separated field. NOTE(review): relies on the scene's naming
        # convention; names without a "_" are used unchanged.
        parts = body_name.split("_")
        if len(parts) >= 2:
            parts[-2] = "0"
        top_level_name = "_".join(parts)

        geom_center, geom_size = geom_aabb(model, data, [geom_id], tight_mesh=False)
        min_z = geom_center[2] - geom_size[2] / 2
        if min_z > z_threshold:
            high_names.add(top_level_name)
        else:
            low_names.add(top_level_name)

    # If any part of a top-level body is low, the body cannot be considered high.
    high_names -= low_names

    del data, model

    # Pass 2: reload and strip ceilings, lights and high bodies before rendering.
    spec = mujoco.MjSpec.from_file(model_path)
    # Iterate over a snapshot in case the bindings expose a live child list.
    for body in list(spec.worldbody.bodies):
        body_name = body.name
        lowered = body_name.lower() if body_name else ""
        if "ceiling" in lowered or "light" in lowered or body_name in high_names:
            spec.delete(body)

    model, data = cls.safe_model_data(spec)
    floor_ids = find_visual_floor_ids(model)

    if camera is None:
        aabb_center, aabb_size = geom_aabb(model, data, floor_ids, tight_mesh=False)
        aabb_size += np.array([2, 2, 0])  # add 1m buffer to each side
        render_cam = mujoco.MjvCamera()
        render_cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        render_cam.lookat[:] = aabb_center
        render_cam.distance = 5.0
        render_cam.azimuth = 0
        render_cam.elevation = -90
        render_cam.orthographic = 1
        h, w = round(px_per_m * aabb_size[0]), round(px_per_m * aabb_size[1])
        px_per_m = h / aabb_size[0]  # recompute to account for rounding
        frustum_half_height = aabb_size[0] / 2
    else:
        cam_model = model.cam(camera)
        assert model.cam_orthographic[cam_model.id], "Camera must be orthographic"
        w, h = model.cam_resolution[cam_model.id]
        # For an orthographic camera, fovy is the visible height in length units.
        px_per_m = h / cam_model.fovy.item()
        render_cam = camera
        frustum_half_height = None

    renderer = _get_renderer(
        model, width=w, height=h, device_id=device_id, use_filament=use_filament
    )
    # Release the renderer even if an update, assert or render fails.
    try:
        renderer.update(data, render_cam)
        if frustum_half_height is not None:
            for gl_cam in renderer.scene.camera:
                gl_cam.orthographic = 1
                gl_cam.frustum_bottom = -frustum_half_height
                gl_cam.frustum_top = frustum_half_height

        scene_cam = renderer.scene.camera[0]
        cam_to_world = np.eye(4)
        cam_to_world[:3, 3] = scene_cam.pos
        camera_x_ax = np.cross(scene_cam.up, -scene_cam.forward)
        cam_to_world[:3, :3] = np.column_stack((camera_x_ax, scene_cam.up, -scene_cam.forward))
        assert np.allclose(cam_to_world[:3, 2], [0, 0, 1]), "Camera must be pointing straight down"

        renderer.enable_segmentation_rendering()
        seg = renderer.render()[..., 0]  # channel 0 holds geom ids
    finally:
        renderer.close()

    # Assemble the occupancy map (True = occupied): everything that is not floor.
    occupancy = np.ones_like(seg, dtype=bool)
    for floor_id in floor_ids:
        occupancy &= seg != floor_id
    # Dilate obstacles to account for the agent radius.
    if agent_radius is not None:
        rad_px = int(agent_radius * px_per_m)
        kernel = circular_kernel(rad_px)
        occupancy = cv2.dilate(occupancy.astype(np.uint8), kernel).astype(bool)

    # NOTE: removing small isolated free-space islands (regions outside the
    # walls) by keeping the largest connected component is disabled, because
    # the largest component is sometimes the region outside the wall.

    # transforms (x,y,z,1) in camera frame to (row, col) in the map
    cam_to_map = np.array([[0, -px_per_m, 0, h / 2], [px_per_m, 0, 0, w / 2]])
    # transforms (x,y,z,1) in world frame to (row, col) in the map
    world_to_map = cam_to_map @ inverse_homogeneous_matrix(cam_to_world)

    # converts (row, col, 1) to (x, y, 1) in camera frame, in pixels
    map_to_centered = np.array([[0, 1, -w / 2], [-1, 0, h / 2], [0, 0, 1]])
    # transforms (x, y, 1) from pixels to (x, y, 1) in camera frame
    centered_to_cam = np.array([[1 / px_per_m, 0, 0], [0, 1 / px_per_m, 0], [0, 0, 1]])
    # transforms (x, y, 1) in camera frame to (x, y, 0) in world frame
    cam_to_world_floor = cam_to_world[:-1, [0, 1, 3]].copy()
    cam_to_world_floor[2, 2] = 0
    # transforms (row, col, 1) in map to (x, y, 0) in world frame
    map_to_world = cam_to_world_floor @ centered_to_cam @ map_to_centered

    # Flip the occupancy map: True is free, False is occupied.
    occupancy = ~occupancy

    # Drop the temporary MuJoCo objects and force collection to free them promptly.
    del model
    del data
    gc.collect()

    return cls(occupancy, world_to_map, map_to_world, px_per_m)
get_free_points
get_free_points() -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
def get_free_points(self) -> np.ndarray:
    """Return the world coordinates of every free (True) occupancy cell.

    Cells are enumerated in row-major order as ``(row, col)`` pairs and then
    converted with ``pos_px_to_m``.
    """
    rows, cols = np.nonzero(self.occupancy)
    pixel_coords = np.stack((rows, cols), axis=-1)
    return self.pos_px_to_m(pixel_coords)
get_free_points_by_room
get_free_points_by_room(room_key: str) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
def get_free_points_by_room(self, room_key: str) -> np.ndarray:
    """Return world coordinates of the free cells belonging to room ``room_key``.

    ``room_key`` is looked up in ``room_names_to_id`` (KeyError if absent).
    Free cells whose ``room_map`` label equals that id are converted with
    ``pos_px_to_m``, in row-major order.
    """
    target_id = self.room_names_to_id[room_key]
    free_rows, free_cols = np.nonzero(self.occupancy)
    in_room = self.room_map[free_rows, free_cols] == target_id
    selected = np.stack((free_rows[in_room], free_cols[in_room]), axis=-1)
    return self.pos_px_to_m(selected)
load classmethod
load(path: str, agent_radius: float | None = None)
Source code in molmo_spaces/utils/scene_maps.py
@classmethod
def load(cls, path: str, agent_radius: float | None = None):
    if path.endswith(".png"):
        # stacked images
        img = Image.open(path)
        # first channel is occupancy, second channel is room map

        world_to_map = np.array(json.loads(img.info["world_to_map"]))
        map_to_world = np.array(json.loads(img.info["map_to_world"]))
        px_per_m = int(np.ceil(json.loads(img.info["px_per_m"])))
        occupancy = np.array(img) > 0

        if agent_radius is not None:
            occupancy = ~occupancy
            rad_px = int(agent_radius * px_per_m)
            kernel = circular_kernel(rad_px)
            occupancy = cv2.dilate(occupancy.astype(np.uint8), kernel).astype(bool)
            occupancy = ~occupancy

        return cls(
            occupancy=occupancy,
            world_to_map=world_to_map,
            map_to_world=map_to_world,
            px_per_m=px_per_m,
        )
    elif path.endswith(".npz"):
        data = np.load(path)
        occupancy = data["occupancy"]

        if agent_radius is not None:
            rad_px = int(agent_radius * data["px_per_m"])
            kernel = circular_kernel(rad_px)
            occupancy = cv2.dilate(occupancy.astype(np.uint8), kernel).astype(bool)

        return cls(
            occupancy=occupancy,
            world_to_map=data["world_to_map"],
            map_to_world=data["map_to_world"],
            px_per_m=data["px_per_m"],
        )
    else:
        raise ValueError(f"Unsupported file format: {path}")
pos_m_to_px
pos_m_to_px(pos_m: ndarray) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def pos_m_to_px(self, pos_m: np.ndarray) -> np.ndarray:
    """Convert metric world positions to integer map-pixel coordinates.

    The body requires an (N, 3) batch. Points are homogenized, mapped through
    ``self.world_to_map`` and rounded to the nearest integer pixel.
    NOTE(review): ``single_or_batch`` presumably lets callers pass one point
    too — confirm against its definition.
    """
    assert pos_m.ndim == 2 and pos_m.shape[-1] == 3
    projected = homogenize(pos_m) @ self.world_to_map.T
    return np.round(projected).astype(int)
pos_px_to_m
pos_px_to_m(pos_px: ndarray) -> ndarray
Source code in molmo_spaces/utils/scene_maps.py
@single_or_batch
def pos_px_to_m(self, pos_px: np.ndarray) -> np.ndarray:
    """Convert map-pixel coordinates back to world-frame positions.

    The body requires an (N, 2) batch of pixel coordinates. They are
    homogenized and mapped through ``self.map_to_world``; no rounding is applied.
    """
    assert pos_px.ndim == 2 and pos_px.shape[-1] == 2
    world = homogenize(pos_px) @ self.map_to_world.T
    return world
safe_model_data staticmethod
safe_model_data(spec, data=None)
Source code in molmo_spaces/utils/scene_maps.py
@staticmethod
def safe_model_data(spec, data=None):
    """Compile ``spec`` into a MuJoCo model, creating fresh data when needed.

    Bodies matching blacklisted asset UIDs are removed from ``spec`` first
    (presumably in place — see ``_delete_blacklisted_bodies``). If compilation
    raises ValueError, the error is handed to
    ``_handle_compile_error_and_blacklist`` and then re-raised.

    Args:
        spec: the MuJoCo spec to compile.
        data: existing data to return unchanged; when None, a new ``MjData`` is
            created for the compiled model and ``mj_forward`` is run on it.

    Returns:
        ``(model, data)``.
    """
    # Strip blacklisted bodies up front so they cannot break compilation.
    _delete_blacklisted_bodies(spec)

    try:
        model = spec.compile()
    except ValueError as compile_err:
        _handle_compile_error_and_blacklist(compile_err)
        raise
    finally:
        # Drops only this frame's reference; the caller's spec object survives.
        del spec

    if data is not None:
        return model, data

    fresh = mujoco.MjData(model)
    mujoco.mj_forward(model, fresh)
    return model, fresh
save
save(path: str)
Source code in molmo_spaces/utils/scene_maps.py
def save(self, path: str):
    """Write the map to ``path`` as ``.png`` or ``.npz``.

    PNG: the occupancy mask is written as a single-channel 8-bit image (255 where
    ``self.occupancy`` is truthy, 0 elsewhere). ``world_to_map``,
    ``map_to_world``, ``px_per_m`` and ``room_ids_to_name`` are stored as JSON
    text chunks.
    NPZ: occupancy, both transforms and ``px_per_m`` are stored as arrays.

    Raises:
        ValueError: if ``path`` has an unsupported extension.
    """
    if path.endswith(".png"):
        img = Image.fromarray(self.occupancy.astype(np.uint8) * 255)

        metadata = PngInfo()
        metadata.add_text("world_to_map", json.dumps(self.world_to_map.tolist()))
        metadata.add_text("map_to_world", json.dumps(self.map_to_world.tolist()))
        # px_per_m can be a NumPy scalar / 0-d array (e.g. after loading from
        # .npz), which json cannot serialize; .item() yields a plain Python number.
        metadata.add_text("px_per_m", json.dumps(np.asarray(self.px_per_m).item()))
        metadata.add_text("room_ids_to_name", json.dumps(self.room_ids_to_name))
        img.save(path, pnginfo=metadata)
    elif path.endswith(".npz"):
        np.savez(
            path,
            occupancy=self.occupancy,
            world_to_map=self.world_to_map,
            map_to_world=self.map_to_world,
            px_per_m=self.px_per_m,
        )
    else:
        raise ValueError(f"Unsupported file format: {path}")
save_map
save_map(path)
Source code in molmo_spaces/utils/scene_maps.py
def save_map(self, path):
    """Write the occupancy map to a PNG; voxel-map export is not implemented.

    Raises:
        AssertionError: if an occupancy map is present and ``path`` is not
            ``.png``, or a voxel map is present and ``path`` is not ``.npy``.
        NotImplementedError: whenever a voxel map is present.

    NOTE(review): with both maps set, no single ``path`` can satisfy both
    extension checks.
    """
    has_occupancy = self.occupancy_map is not None
    if has_occupancy:
        assert path.endswith(".png"), "Only PNG format is supported"
        cv2.imwrite(path, self.occupancy_map)

    if self.voxel_map is None:
        return
    assert path.endswith(".npy"), "Only NPY format is supported"
    raise NotImplementedError("Voxel map is not implemented yet")

circular_kernel

circular_kernel(radius: int)
Source code in molmo_spaces/utils/scene_maps.py
def circular_kernel(radius: int):
    """Return a filled-disk uint8 structuring element of side ``2 * radius + 1``.

    Pixels inside the disk (as rasterized by ``cv2.circle``) are 1, others 0.
    """
    side = 2 * radius + 1
    center = (radius, radius)
    disk = np.zeros((side, side), dtype=np.uint8)
    # color=1, thickness=-1 -> filled circle
    cv2.circle(disk, center, radius, 1, -1)
    return disk

sample_around_point

sample_around_point(thormap: ProcTHORMap | iTHORMap, point: ndarray, radius_range: tuple[float, float], fallback_threshold: float = 0.05, max_iter: int = 100) -> ndarray

Sample a 2D point around a given point, at a distance within a given radius range.

Source code in molmo_spaces/utils/scene_maps.py
def sample_around_point(
    thormap: "ProcTHORMap | iTHORMap",
    point: np.ndarray,
    radius_range: tuple[float, float],
    fallback_threshold: float = 0.05,
    max_iter: int = 100,
) -> np.ndarray:
    """
    Sample a 2D point around a given point within a given radius range.

    Free map points whose xy distance to ``point`` lies strictly inside
    ``radius_range`` define the candidate annulus.

    If the free fraction of that annulus exceeds ``fallback_threshold``, the
    function uses rejection sampling. Continuous candidates are drawn with
    uniform angle and uniform radius, lifted to z = 0, and the first one for
    which ``thormap.check_collision`` is truthy is returned.

    Otherwise, or if all ``max_iter`` batches are rejected, a random free map
    point inside the annulus is returned instead.

    Args:
        thormap: map exposing ``get_free_points()``, ``px_per_m`` and
            ``check_collision(points_3d)``.
        point: (2,) xy center.
        radius_range: ``(min, max)`` distance from ``point``.
        fallback_threshold: minimum free-area fraction for rejection sampling.
        max_iter: maximum number of rejection-sampling batches.

    Returns:
        A (2,) xy position.

    Raises:
        ValueError: if no free map point lies within ``radius_range`` of ``point``.
    """
    assert point.shape == (2,), "Point must be a 2D array"
    r_min, r_max = radius_range

    free_points = thormap.get_free_points()
    target_dist = np.linalg.norm(free_points[:, :2] - point[None], axis=1)
    valid_mask = (target_dist > r_min) & (target_dist < r_max)
    # Kept under its own name: the rejection loop below must not overwrite it,
    # because the fallback at the end samples from it.
    candidate_points = free_points[valid_mask]
    if len(candidate_points) == 0:
        raise ValueError(
            f"No free map points within {radius_range} of {point}; cannot sample."
        )

    sq_m_per_sq_px = 1 / (thormap.px_per_m**2)
    # Non-empty candidates imply r_min < r_max, so the annulus area is positive.
    valid_neighborhood_frac = (
        len(candidate_points) * sq_m_per_sq_px / (np.pi * (r_max**2 - r_min**2))
    )

    if valid_neighborhood_frac > fallback_threshold:
        # Expected number of draws per hit; in expectation one batch suffices.
        batch_size = int(np.ceil(1 / valid_neighborhood_frac).item())
        for i in range(max_iter):
            theta = np.random.uniform(0, 2 * np.pi, size=batch_size)
            r = np.random.uniform(r_min, r_max, size=batch_size)
            sampled_points = point[None] + np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)
            sampled_points_3d = np.concatenate([sampled_points, np.zeros((batch_size, 1))], axis=1)
            accepted = thormap.check_collision(sampled_points_3d)
            if accepted.any():
                first = np.where(accepted)[0][0]
                log.debug(
                    f"Sampled point from map after {i + 1} iterations, {valid_neighborhood_frac=:.1%}"
                )
                return sampled_points[first]
        log.warning(
            f"Failed to sample a point from map after {max_iter} iterations, falling back to backup. {valid_neighborhood_frac=:.1%}"
        )
    else:
        log.warning(
            f"Less than {fallback_threshold:.0%} of the sampling area is free, sampling specific pixel. This is less robust to cross-platform variation."
        )

    # Fallback: a random free map point in the annulus. Only the xy columns are
    # returned so the result has the same (2,) shape as the sampled branch.
    return candidate_points[np.random.randint(len(candidate_points)), :2]

scene_metadata_utils

Classes:

Name Description
SceneMeta

Functions:

Name Description
ensure_all_scenes_installed
get_scene_metadata

Get scene metadata from the scene path.

is_object_articulable_from_metadata

Return True if the object has at least one hinge or slide joint per scene metadata.

synsets_to_scenes_and_assets

Attributes:

Name Type Description
ctime
log
meta

ctime module-attribute

ctime = time() - ctime

log module-attribute

log = getLogger(__name__)

meta module-attribute

meta = for_dataset_split(dataset, split)

SceneMeta

Methods:

Name Description
extraction_dir
for_dataset_split
for_split
get_scene_metadata
scene_datasets
extraction_dir staticmethod
extraction_dir(data_source: str) -> Path
Source code in molmo_spaces/utils/scene_metadata_utils.py
@staticmethod
def extraction_dir(data_source: str) -> Path:
    """Return ``<cache_dir>/scenes/<data_source>/<installed version>``.

    All scenes are installed first (if not already). Asserts that the resulting
    directory exists.
    """
    ensure_all_scenes_installed()
    source_root = get_resource_manager().cache_dir / "scenes" / data_source
    scene_dir = source_root / get_resource_manager().versions["scenes"][data_source]
    assert scene_dir.is_dir()
    return scene_dir
for_dataset_split cached classmethod
for_dataset_split(data_source, split)
Source code in molmo_spaces/utils/scene_metadata_utils.py
@classmethod
@functools.lru_cache
def for_dataset_split(cls, data_source, split):
    """Load per-scene metadata for one dataset split, keyed by scene index.

    Each scene entry is a Path or a mapping of variant -> path. For a mapping,
    the first non-None path is used. Entries that are None, or whose values are
    all None, are skipped. Installed paths are re-rooted from the symlink
    install dir onto the extraction dir before metadata is read. Results are
    cached per ``(cls, data_source, split)``.
    """
    extract_dir = cls.extraction_dir(data_source)
    install_dir = get_resource_manager().symlink_dir / "scenes" / data_source
    scenes = get_scenes(data_source, split)[split]

    def _installed_path(scene_info):
        """Installed scene path for ``scene_info``, or None to skip the scene."""
        if scene_info is None:
            return None
        if isinstance(scene_info, Path):
            return scene_info
        return next((Path(v) for v in scene_info.values() if v is not None), None)

    meta = {}
    for scene_idx, scene_info in scenes.items():
        installed = _installed_path(scene_info)
        if installed is None:
            continue
        meta[scene_idx] = get_scene_metadata(extract_dir / installed.relative_to(install_dir))
    return meta
for_split classmethod
for_split(split)
Source code in molmo_spaces/utils/scene_metadata_utils.py
@classmethod
def for_split(cls, split):
    """Map every scene dataset name to its ``for_dataset_split(dataset, split)`` result."""
    return {
        dataset: cls.for_dataset_split(dataset, split)
        for dataset in cls.scene_datasets()
    }
get_scene_metadata staticmethod
get_scene_metadata(mj_base_scene_path: str | Path) -> dict
Source code in molmo_spaces/utils/scene_metadata_utils.py
@staticmethod
def get_scene_metadata(mj_base_scene_path: str | Path) -> dict | None:
    """Delegate to the module-level ``get_scene_metadata``.

    Returns the parsed metadata dict, or None when no metadata file is found
    (the return annotation previously claimed ``dict`` only).
    """
    return get_scene_metadata(mj_base_scene_path)
scene_datasets staticmethod
scene_datasets() -> list[str]
Source code in molmo_spaces/utils/scene_metadata_utils.py
@staticmethod
def scene_datasets() -> list[str]:
    """Sorted names of the installed scene datasets, excluding the ``refs`` entry."""
    names = set(get_resource_manager().versions["scenes"])
    names.discard("refs")
    return sorted(names)

ensure_all_scenes_installed

ensure_all_scenes_installed()
Source code in molmo_spaces/utils/scene_metadata_utils.py
def ensure_all_scenes_installed():
    """Install every scene package once per process (symlinking is skipped).

    A module-level flag records completion, so later calls are no-ops.
    """
    global _ALL_SCENES_INSTALLED
    if _ALL_SCENES_INSTALLED:
        return

    print("Installing all scenes")
    get_resource_manager().install_all_for_data_type("scenes", skip_linking=True)
    _ALL_SCENES_INSTALLED = True

get_scene_metadata

get_scene_metadata(mj_base_scene_path: str | Path) -> dict | None

Get scene metadata from the scene path.

Source code in molmo_spaces/utils/scene_metadata_utils.py
def get_scene_metadata(mj_base_scene_path: str | Path) -> dict | None:
    """Get scene metadata from the scene path."""

    # Just in case we received a Path instead of a str
    mj_base_scene_path = str(mj_base_scene_path)

    assert mj_base_scene_path.endswith(".xml"), (
        f"Scene is supposed to be xml ({mj_base_scene_path} given)"
    )

    if "ceiling" in mj_base_scene_path:
        metadata_file = mj_base_scene_path.replace("_ceiling.xml", "_metadata.json")
    else:
        metadata_file = mj_base_scene_path.replace(".xml", "_metadata.json")

    if not Path(metadata_file).exists():
        # Fallback: metadata search by iteratively removing underscore-connected suffixes
        dir_path = Path(mj_base_scene_path).parent

        # First attempt to replace .xml by _metadata.json,
        # then continue removing suffixes separated by "_" until something found
        scene_name = str(Path(mj_base_scene_path).name).replace(".xml", "")
        parts = scene_name.split("_")

        while parts:
            cur_name = "_".join(parts + ["metadata.json"])
            if (dir_path / cur_name).exists():
                with open(dir_path / cur_name, "r") as f:
                    return json.load(f)

            # Not found, remove last part and try again
            parts.pop()

        log.warning(f"Scene metadata file not found for {mj_base_scene_path}")

        return None

    with open(metadata_file, "r") as f:
        metadata = json.load(f)

    return metadata

is_object_articulable_from_metadata

is_object_articulable_from_metadata(model: MjModel, scene_metadata: dict, object_name: str) -> bool

Return True if the object has at least one hinge or slide joint per scene metadata.

Uses the scene's name_map for joints and checks each corresponding MuJoCo joint type.

Source code in molmo_spaces/utils/scene_metadata_utils.py
def is_object_articulable_from_metadata(
    model: MjModel, scene_metadata: dict, object_name: str
) -> bool:
    """Return True if the object has at least one hinge or slide joint per scene metadata.

    Joint names are the keys of
    ``scene_metadata["objects"][object_name]["name_map"]["joints"]``; each one is
    looked up in ``model`` and its type is checked. Missing or empty metadata
    yields False.
    """
    if not scene_metadata:
        return False

    joint_maps: dict | None = (
        scene_metadata.get("objects", {})
        .get(object_name, {})
        .get("name_map", {})
        .get("joints", None)
    )
    if not joint_maps:
        return False

    hinge = mujoco.mjtJoint.mjJNT_HINGE
    slide = mujoco.mjtJoint.mjJNT_SLIDE
    # any() stops at the first articulated joint, like the early return did.
    return any(
        (jtype := model.joint(name).type) == hinge or jtype == slide
        for name in joint_maps.keys()
    )

synsets_to_scenes_and_assets cached

synsets_to_scenes_and_assets()
Source code in molmo_spaces/utils/scene_metadata_utils.py
@functools.lru_cache
def synsets_to_scenes_and_assets():
    """Index training-split scenes and assets by object synset.

    Returns:
        ``(synset_to_scenes, synset_to_assets)``: defaultdicts mapping a synset
        name to the set of ``(dataset, scene_index)`` pairs containing an object
        of that synset, and to the set of asset ids annotated with it. The
        result is cached, so callers should treat it as read-only.
    """
    # Local import kept from the original (presumably avoids an import cycle);
    # hoisted out of the innermost loop.
    from molmo_spaces.utils.object_metadata import ObjectMeta

    synset_to_scenes = defaultdict(set)
    synset_to_assets = defaultdict(set)

    for dataset, index_to_meta in SceneMeta.for_split("train").items():
        for index, meta in index_to_meta.items():
            if not meta:
                # get_scene_metadata returns None when no metadata file exists.
                continue
            for entry in meta.get("objects", {}).values():
                asset_id = entry["asset_id"]
                ometa = ObjectMeta.annotation(asset_id)
                if ometa:
                    synset = ometa["synset"]
                    synset_to_scenes[synset].add((dataset, index))
                    synset_to_assets[synset].add(asset_id)

    return synset_to_scenes, synset_to_assets

spatial_utils

Quaternions are assumed to be scalar first!

Classes:

Name Description
Transform

Functions:

Name Description
look_at

Transform

Transform(translation, rotation)

Classes:

Name Description
TClass

Convenient way to create a pure translation.

Methods:

Name Description
__mul__
apply
as_matrix
from_list
from_matrix
from_rotation
from_translation
identity
inv
look_at
to_list

Attributes:

Name Type Description
rotation
t_
translation
Source code in molmo_spaces/utils/spatial_utils.py
def __init__(self, translation, rotation) -> None:
    """Store private copies of ``translation`` (as float64) and ``rotation``.

    Copies ensure later mutation of the caller's objects cannot alter this
    transform.
    """
    self.translation = np.array(translation, dtype=np.double, copy=True)
    self.rotation = copy.deepcopy(rotation)
rotation instance-attribute
rotation = deepcopy(rotation)
t_ class-attribute instance-attribute
t_ = TClass()
translation instance-attribute
translation = np.asarray(translation, np.double).copy()
TClass

Convenient way to create a pure translation.

Transform.t_[x, y, z] is equivalent to Transform.from_translation(np.r_[x, y, z]).

Methods:

Name Description
__getitem__
__getitem__
__getitem__(key)
Source code in molmo_spaces/utils/spatial_utils.py
def __getitem__(self, key):
    """Build a pure translation: ``t_[x, y, z]`` -> ``from_translation(np.r_[x, y, z])``."""
    offset = np.r_[key]
    return Transform.from_translation(offset)
__mul__
__mul__(other)
Source code in molmo_spaces/utils/spatial_utils.py
def __mul__(self, other):
    """Compose transforms so ``(self * other).apply(p) == self.apply(other.apply(p))``."""
    composed_rotation = self.rotation * other.rotation
    moved_origin = self.rotation.apply(other.translation)
    return self.__class__(moved_origin + self.translation, composed_rotation)
apply
apply(point)
Source code in molmo_spaces/utils/spatial_utils.py
def apply(self, point):
    """Map ``point`` (or an array of points) through this rigid transform: R @ p + t."""
    rotated = self.rotation.apply(point)
    return rotated + self.translation
as_matrix
as_matrix()
Source code in molmo_spaces/utils/spatial_utils.py
def as_matrix(self):
    """Return the 4x4 homogeneous matrix ``[[R, t], [0, 0, 0, 1]]``."""
    homogeneous = np.eye(4)
    homogeneous[:3, :3] = self.rotation.as_matrix()
    homogeneous[:3, 3] = self.translation
    return homogeneous
from_list classmethod
from_list(quat_list)
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def from_list(cls, quat_list):
    return cls(quat_list[:3], Rotation.from_quat(quat_list[3:], scalar_first=True))
from_matrix classmethod
from_matrix(matrix)
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def from_matrix(cls, matrix):
    """Build from a 4x4 homogeneous matrix (upper-left 3x3 rotation, last column translation)."""
    return cls(matrix[:3, 3], Rotation.from_matrix(matrix[:3, :3]))
from_rotation classmethod
from_rotation(rotation)
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def from_rotation(cls, rotation):
    """Pure rotation: ``rotation`` with a zero translation."""
    return cls(np.zeros(3), rotation)
from_translation classmethod
from_translation(translation)
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def from_translation(cls, translation):
    """Pure translation: ``translation`` with the identity rotation."""
    return cls(translation, Rotation.identity())
identity classmethod
identity()
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def identity(cls):
    """The identity transform: zero translation, identity rotation."""
    return cls(np.zeros(3), Rotation.identity())
inv
inv()
Source code in molmo_spaces/utils/spatial_utils.py
def inv(self):
    """Return the inverse transform: rotation R^-1 and translation -R^-1 t."""
    inverse_rotation = self.rotation.inv()
    return self.__class__(-inverse_rotation.apply(self.translation), inverse_rotation)
look_at classmethod
look_at(eye, target, up)
Source code in molmo_spaces/utils/spatial_utils.py
@classmethod
def look_at(cls, eye, target, up):
    """Build a pose at ``eye`` whose z axis points toward ``target``.

    Rotation columns are ``(right, -up, forward)``, with
    ``right = forward x up`` and ``up`` re-orthogonalized. The frame's x axis is
    "right", y is "down" and z is "forward".

    If ``up`` is (nearly) parallel to the view direction, it is nudged along x,
    as before. If the view direction is itself along x, that nudge cannot help
    (it previously produced NaNs). In that case the world axis least aligned
    with ``forward`` is used instead.

    Raises:
        ValueError: if ``eye`` and ``target`` coincide (previously produced NaNs).
    """
    forward = np.subtract(target, eye)
    dist = np.linalg.norm(forward)
    if dist == 0:
        raise ValueError("look_at: eye and target must be distinct points")
    forward = np.divide(forward, dist)

    right = np.cross(forward, up)
    if np.linalg.norm(right) < 1e-3:
        right = np.cross(forward, up + np.r_[1e-3, 0, 0])
        if np.linalg.norm(right) < 1e-9:
            helper_axis = np.eye(3)[np.argmin(np.abs(forward))]
            right = np.cross(forward, helper_axis)
    right = np.divide(right, np.linalg.norm(right))

    up = np.cross(right, forward)
    up = np.divide(up, np.linalg.norm(up))

    m = np.array(
        [
            [right[0], -up[0], forward[0], eye[0]],
            [right[1], -up[1], forward[1], eye[1]],
            [right[2], -up[2], forward[2], eye[2]],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )

    return cls.from_matrix(m)
to_list
to_list()
Source code in molmo_spaces/utils/spatial_utils.py
def to_list(self):
    return np.r_[self.translation, self.rotation.as_quat(scalar_first=True)]

look_at

look_at(eye, center, up)
Source code in molmo_spaces/utils/spatial_utils.py
def look_at(eye, center, up):
    """Return a ``Transform`` at ``eye`` whose z axis points toward ``center``.

    Rotation columns are ``(right, -up, forward)``, with
    ``right = forward x up``, matching ``Transform.look_at``.

    Inputs are converted to float first. Previously, the in-place division
    raised for integer inputs such as ``look_at([0, 0, 0], [1, 0, 0], [0, 0, 1])``.
    A (nearly) parallel ``up`` is handled the same way as in ``Transform.look_at``
    instead of producing NaNs.

    Raises:
        ValueError: if ``eye`` and ``center`` coincide.
    """
    eye = np.asarray(eye, dtype=float)
    center = np.asarray(center, dtype=float)
    up = np.asarray(up, dtype=float)

    forward = center - eye
    dist = np.linalg.norm(forward)
    if dist == 0:
        raise ValueError("look_at: eye and center must be distinct points")
    forward /= dist

    right = np.cross(forward, up)
    if np.linalg.norm(right) < 1e-3:
        right = np.cross(forward, up + np.r_[1e-3, 0, 0])
        if np.linalg.norm(right) < 1e-9:
            helper_axis = np.eye(3)[np.argmin(np.abs(forward))]
            right = np.cross(forward, helper_axis)
    right /= np.linalg.norm(right)
    # right and forward are orthonormal, so this is already unit length.
    up = np.cross(right, forward)

    m = np.eye(4, 4)
    m[:3, 0] = right
    m[:3, 1] = -up
    m[:3, 2] = forward
    m[:3, 3] = eye
    return Transform.from_matrix(m)

synset_utils

Functions:

Name Description
canonical_lemma

Return the first (most canonical) lemma for a WordNet synset name.

filter_synsets_to_remove_hyponyms
generate_all_hypernyms_with_exclusions
generate_hypernym_to_descendants
get_all_synsets_in_metadata
get_highest_relevant_hypernym
get_hypernym_to_descendants_for_all_metadata_synsets
get_hyponyms_of_synset
get_hyponyms_of_synsets
get_singleton_highest_hypernyms
get_valid_pickupable_obja_uids

Get all objaverse asset UIDs that are pickable (have valid grasp files).

get_valid_pickupable_obja_uids_excluding_benchmark

Get pickupable objaverse UIDs with benchmark assets excluded.

get_valid_receptacle_uids

Get all asset UIDs that are valid receptacles based on synset filtering.

is_hypernym_of
is_subsynset_of
is_valid_receptacle_synset

Check if a synset is a valid receptacle based on inclusion/exclusion rules.

symmetric_subsynset_of

Attributes:

Name Type Description
BENCHMARK_BLACKLIST_UIDS_PATH
EXCLUDED_HYPERNYMS
PICKUPABLE_EXCLUDED_CATEGORY_HYPERNYMS dict[str, str]
PICKUPABLE_EXCLUDED_EXACT_SYNSETS dict[str, str]
RECEPTACLE_HYPERNYM_INCLUDE_WITH_EXCLUSIONS dict[str, set[str]]
RECEPTACLE_INCLUDE_SYNSETS set[str]
VALID_PICKUPABLE_OBJA_UIDS_PATH

BENCHMARK_BLACKLIST_UIDS_PATH module-attribute

BENCHMARK_BLACKLIST_UIDS_PATH = '/weka/prior/datasets/robomolmo/asset_utility_refs/benchmark_blacklist_uids.txt'

EXCLUDED_HYPERNYMS module-attribute

EXCLUDED_HYPERNYMS = frozenset({'abstraction.n.04', 'abstraction.n.06', 'accident.n.01', 'accomplice.n.01', 'accumulation.n.04', 'act.n.02', 'acting.n.01', 'action.n.01', 'action.n.07', 'activity.n.01', 'administrative_unit.n.01', 'admirer.n.03', 'adult.n.01', 'affair.n.03', 'agaric.n.02', 'agglomeration.n.01', 'air_unit.n.01', 'alloy.n.01', 'animal_material.n.01', 'animal_order.n.01', 'animal_product.n.01', 'announcement.n.02', 'anomaly.n.02', 'aperture.n.03', 'appearance.n.01', 'appearance.n.02', 'appearance.n.04', 'application.n.03', 'approval.n.04', 'archosaur.n.01', 'arctiid.n.01', 'area.n.05', 'area.n.06', 'aristocrat.n.01', 'army_unit.n.01', 'arrangement.n.02', 'arrangement.n.03', 'art.n.03', 'art_form.n.01', 'arthropod_family.n.01', 'arthropod_genus.n.01', 'article.n.02', 'articulator.n.02', 'artifact.n.01', 'artificial_intelligence.n.01', 'artificial_language.n.01', 'artillery.n.02', 'artistic_style.n.01', 'asphodel.n.01', 'assembly.n.01', 'assembly.n.05', 'assembly.n.06', 'assets.n.01', 'assistant.n.01', 'associate.n.01', 'association.n.08', 'atom.n.02', 'atomic_theory.n.01', 'attempt.n.01', 'attendant.n.01', 'attitude.n.01', 'attribute.n.02', 'auditory_communication.n.01', 'autoloader.n.01', 'automatic_firearm.n.01', 'avoirdupois_unit.n.01', 'axis.n.06', 'back.n.08', 'base.n.01', 'basic_cognitive_process.n.01', 'basidiomycete.n.01', 'beginning.n.05', 'being.n.01', 'belief.n.01', 'benzene.n.01', 'bill.n.07', 'binary_compound.n.01', 'bioassay.n.01', 'biological_group.n.01', 'biometric_identification.n.01', 'blemish.n.01', 'body.n.02', 'body.n.04', 'body_part.n.01', 'bodybuilding.n.01', 'boundary.n.01', 'bowling.n.01', 'bramble_bush.n.01', 'bryophyte.n.01', 'business.n.01', 'businessperson.n.01', 'calcium_carbonate.n.01', 'calcium_sulphate.n.01', 'capitalist.n.02', 'capsule.n.03', 'capsule.n.05', 'carbon.n.01', 'care.n.01', 'caryophylloid_dicot_genus.n.01', 'category.n.02', 'catholic_church.n.01', 'causal_agent.n.01', 'cavity.n.02', 'center.n.01', 
'center.n.04', 'center.n.06', 'central.n.01', 'ceratopsian.n.01', 'cetacean.n.01', 'change.n.03', 'change_of_location.n.01', 'change_of_state.n.01', 'character.n.04', 'character.n.08', 'chemical_phenomenon.n.01', 'chemoreceptor.n.01', 'chicory.n.04', 'child.n.01', 'child.n.02', 'chordate.n.01', 'circle.n.01', 'class.n.03', 'clef.n.01', 'clown.n.02', 'clue.n.02', 'code.n.03', 'coding_system.n.01', 'cognition.n.01', 'cognitive_factor.n.01', 'collection.n.01', 'collision.n.02', 'color.n.01', 'combatant.n.01', 'comedian.n.01', 'commodity.n.01', 'communication.n.02', 'complexity.n.01', 'component.n.03', 'composition.n.03', 'compound_leaf.n.01', 'compression.n.04', 'computer_graphics.n.01', 'computer_network.n.01', 'computer_science.n.01', 'concealment.n.03', 'concept.n.01', 'conduit.n.01', 'confinement.n.03', 'conic_section.n.01', 'connection.n.01', 'consequence.n.01', 'constitution.n.04', 'constraint.n.01', 'consumer_credit.n.01', 'consumer_goods.n.01', 'content.n.05', 'contestant.n.01', 'control.n.05', 'convex_shape.n.01', 'cook.n.01', 'cooking.n.01', 'cookout.n.01', 'coordinate_system.n.01', 'copper-base_alloy.n.01', 'correctional_institution.n.01', 'corrective.n.01', 'course.n.08', 'covering.n.02', 'crack.n.07', 'craftsman.n.03', 'creating_by_removal.n.01', 'creating_from_raw_materials.n.01', 'creation.n.01', 'creation.n.02', 'creator.n.02', 'crest.n.05', 'criminal.n.01', 'cross_section.n.01', 'crossing.n.05', 'crossopterygian.n.01', 'crosspiece.n.02', 'cuisine.n.01', 'cultivation.n.02', 'cyprinodont.n.01', 'danaid.n.01', 'dance_music.n.02', 'dark.n.01', 'database.n.01', 'decapod.n.02', 'deceiver.n.01', 'decline.n.02', 'decorativeness.n.01', 'defender.n.01', 'definite_quantity.n.01', 'deity.n.01', 'delicious.n.01', 'delivery.n.01', 'demonstration.n.05', 'depiction.n.04', 'depository.n.01', 'depression.n.08', 'design.n.02', 'design.n.04', 'detail.n.02', 'determinant.n.01', 'development.n.06', 'device.n.01', 'diapsid.n.01', 'dicot_genus.n.01', 'diet.n.01', 
'difficulty.n.02', 'digit.n.01', 'direction.n.06', 'discharge.n.03', 'discipline.n.01', 'discrimination.n.02', 'disorderliness.n.01', 'display.n.05', 'district.n.01', 'ditch.n.01', 'diver.n.01', 'division.n.03', 'division.n.04', 'dresser.n.02', 'drive.n.02', 'drop.n.01', 'dry_masonry.n.01', 'dryad.n.01', 'durables.n.01', 'dwelling.n.01', 'dysphemism.n.01', 'ectoparasite.n.01', 'edge.n.03', 'edge.n.06', 'edging.n.01', 'effect.n.03', 'effort.n.02', 'egotist.n.01', 'elasmobranch.n.01', 'elasticity.n.01', 'electronic_text.n.01', 'elite.n.01', 'ellipse.n.01', 'embankment.n.01', 'emoticon.n.01', 'employee.n.01', 'enamel.n.04', 'enclosure.n.03', 'engineering.n.02', 'enlisted_person.n.01', 'enterprise.n.02', 'entertainment.n.01', 'entity.n.01', 'entree.n.01', 'escape.n.05', 'eubacteria.n.01', 'european.n.01', 'evaluator.n.01', 'even-toed_ungulate.n.01', 'event.n.01', 'evil_spirit.n.01', 'example.n.01', 'excretory_organ.n.01', 'exercise.n.01', 'exhibitionist.n.02', 'expanse.n.03', 'expedient.n.01', 'experience.n.02', 'explanation.n.02', 'explorer.n.01', 'external_body_part.n.01', 'extremity.n.01', 'extremity.n.05', 'extremum.n.02', 'exudate.n.01', 'facial_expression.n.01', 'facial_hair.n.01', 'facility.n.04', 'facing.n.03', 'failure.n.02', 'family.n.06', 'fancier.n.01', 'fare.n.04', 'farming.n.01', 'fashion.n.03', 'feature.n.02', 'feline.n.01', 'female.n.02', 'fern_ally.n.01', 'ferric_oxide.n.01', 'fibril.n.01', 'fiction.n.01', 'field.n.01', 'figuration.n.02', 'financial_gain.n.01', 'fine_arts.n.01', 'finish.n.04', 'fire.n.01', 'firing_range.n.01', 'first_class.n.02', 'flow.n.01', 'flue.n.03', 'fluid.n.02', 'font.n.01', 'foothold.n.02', 'force.n.02', 'forest.n.01', 'formation.n.01', 'formula.n.04', 'formulation.n.01', 'foundry.n.01', 'framework.n.03', 'front.n.04', 'fruitwood.n.01', 'fullerene.n.01', 'fundamental_quantity.n.01', 'gadoid.n.01', 'gain.n.04', 'game_of_chance.n.01', 'gang.n.03', 'ganoid.n.01', 'gas.n.02', 'gastropod.n.01', 'gate.n.04', 'genre.n.03', 
'genus.n.02', 'geographic_point.n.01', 'geographical_area.n.01', 'geometry.n.01', 'girdle.n.01', 'glyptic_art.n.01', 'golf.n.01', 'goosefoot.n.01', 'graphics.n.02', 'greco-roman_deity.n.01', 'greek_deity.n.01', 'grip.n.06', 'groove.n.01', 'group.n.01', 'group_action.n.01', 'hair.n.01', 'happening.n.01', 'hawkmoth.n.01', 'hazard.n.01', 'head.n.04', 'health_hazard.n.01', 'health_professional.n.01', 'heating.n.01', 'hexagram.n.01', 'higher_cognitive_process.n.01', 'hiker.n.01', 'hill.n.01', 'hindrance.n.01', 'hindrance.n.02', 'hindu_deity.n.01', 'history.n.02', 'hole.n.01', 'hole.n.02', 'hole.n.05', 'homespun.n.01', 'homo.n.02', 'horn.n.07', 'housing.n.01', 'humate.n.01', 'humorist.n.01', 'hunting_dog.n.01', 'hydrocarbon.n.01', 'hydrozoan.n.01', 'hypothesis.n.02', 'idea.n.01', 'ideal.n.01', 'idler.n.01', 'illumination.n.02', 'illusion.n.01', 'illustration.n.01', 'imaginary_place.n.01', 'imagination.n.02', 'imaging.n.02', 'immateriality.n.02', 'implement.n.01', 'implementation.n.02', 'impression.n.01', 'incident.n.01', 'income.n.01', 'indefinite_quantity.n.01', 'individual.n.02', 'industry.n.02', 'influence.n.01', 'information.n.02', 'inhabitant.n.01', 'insertion.n.02', 'institution.n.01', 'intake.n.02', 'integer.n.01', 'intellectual.n.01', 'interior_decoration.n.02', 'inventiveness.n.01', 'investigator.n.02', 'iron.n.01', 'isogon.n.01', 'isopod.n.01', 'item.n.01', 'item.n.02', 'item.n.03', 'item.n.04', 'item.n.05', 'item.n.06', 'jack.n.11', 'jail.n.01', 'junction.n.04', 'juvenile.n.01', 'juxtaposition.n.01', 'killer.n.01', 'kind.n.01', 'kingdom.n.01', 'knowledge_domain.n.01', 'labor.n.02', 'laborer.n.01', 'lake.n.01', 'lamination.n.01', 'land.n.01', 'landing.n.02', 'lane.n.02', 'language.n.01', 'language_unit.n.01', 'larid.n.01', 'latex.n.01', 'lawman.n.01', 'layer.n.02', 'leader.n.01', 'leg.n.02', 'legend.n.01', 'leporid.n.01', 'level.n.05', 'life_science.n.01', 'lignite.n.01', 'likeness.n.02', 'liliid_monocot_genus.n.01', 'limit.n.04', 'limit.n.06', 
'linear_unit.n.01', 'lipid.n.01', 'liquid.n.03', 'list.n.01', 'literary_composition.n.01', 'literate.n.01', 'living_quarters.n.01', 'living_thing.n.01', 'local_area_network.n.01', 'location.n.01', 'lookout.n.02', 'lottery.n.02', 'lover.n.01', 'machine.n.02', 'macromolecule.n.01', 'magnitude.n.01', 'main.n.02', 'male_aristocrat.n.01', 'male_child.n.01', 'malformation.n.02', 'man.n.01', 'manner.n.01', 'manual_labor.n.01', 'mark.n.04', 'marking.n.02', 'martial_art.n.01', 'mass_unit.n.01', 'material.n.01', 'material.n.04', 'mathematics.n.01', 'matter.n.01', 'matter.n.02', 'matter.n.03', 'matter.n.06', 'means.n.01', 'measure.n.02', 'mechanism.n.05', 'medical_procedure.n.01', 'meeting.n.01', 'membrane.n.02', 'merchant.n.01', 'metallic_element.n.01', 'message.n.01', 'message.n.02', 'military_quarters.n.01', 'military_unit.n.01', 'minimum.n.01', 'misconception.n.01', 'misfortune.n.01', 'mishap.n.02', 'mixture.n.01', 'molecular_formula.n.01', 'moneran.n.01', 'monetary_unit.n.01', 'monocot_genus.n.01', 'motion.n.06', 'motor_hotel.n.01', 'movement.n.03', 'movement.n.04', 'movement.n.11', 'multidimensional_language.n.01', 'murderer.n.01', 'music.n.01', 'musical_composition.n.01', 'musical_notation.n.01', 'musical_organization.n.01', 'musician.n.01', 'muslim.n.01', 'name.n.01', 'natural_elevation.n.01', 'natural_object.n.01', 'natural_phenomenon.n.01', 'natural_process.n.01', 'natural_science.n.01', 'negotiator.n.01', 'neritid.n.01', 'net_income.n.01', 'nidus.n.02', 'nobility.n.01', 'noise.n.01', 'nongovernmental_organization.n.01', 'nonmetal.n.01', 'nonworker.n.01', 'notation.n.01', 'notion.n.04', 'number.n.02', 'nutrient.n.02', 'object.n.01', 'object.n.04', 'obstacle.n.01', 'occultist.n.01', 'occupation.n.01', 'offspring.n.01', 'oil_paint.n.01', 'oldster.n.01', 'open_chain.n.01', 'operation.n.06', 'orchis.n.01', 'order.n.12', 'order.n.14', 'organelle.n.01', 'organic_compound.n.01', 'organism.n.01', 'organization.n.01', 'orifice.n.01', 'originality.n.01', 'ornithischian.n.01', 
'orthography.n.01', 'oscine.n.01', 'ovule.n.01', 'oxide.n.01', 'pad.n.02', 'padding.n.01', 'parallelepiped.n.01', 'parasite.n.01', 'paring.n.01', 'part.n.01', 'part.n.02', 'part.n.03', 'partial_veil.n.01', 'participant.n.01', 'particle.n.02', 'particulate.n.01', 'passage.n.03', 'patron_saint.n.01', 'pedaler.n.01', 'peer.n.01', 'percept.n.01', 'perception.n.03', 'percussionist.n.01', 'performance.n.02', 'performer.n.01', 'performing_arts.n.01', 'perpendicular.n.02', 'personal_property.n.01', 'phenomenon.n.01', 'physical_entity.n.01', 'physical_phenomenon.n.01', 'physical_property.n.01', 'pictorial_representation.n.01', 'piece.n.01', 'placement.n.01', 'plain.n.01', 'plan.n.01', 'plan.n.03', 'plane_figure.n.01', 'plant_genus.n.01', 'plant_order.n.01', 'plant_organ.n.01', 'plant_part.n.01', 'plant_process.n.01', 'play.n.08', 'player.n.01', 'point_of_view.n.01', 'porcelain.n.01', 'portrayal.n.02', 'poseur.n.01', 'position.n.07', 'position.n.12', 'post.n.01', 'power.n.01', 'practice.n.01', 'practice_range.n.01', 'prayer.n.02', 'presence.n.01', 'presentation.n.02', 'preserver.n.03', 'principal.n.05', 'problem.n.02', 'procedure.n.01', 'process.n.02', 'process.n.05', 'process.n.06', 'prod.n.02', 'product.n.02', 'production.n.02', 'production.n.07', 'profile.n.05', 'program.n.07', 'programming_language.n.01', 'projection.n.04', 'property.n.01', 'property.n.02', 'property.n.04', 'property.n.05', 'proportional_font.n.01', 'propulsion.n.01', 'propulsion.n.02', 'protection.n.01', 'protocol.n.01', 'protoctist.n.01', 'psychological_feature.n.01', 'public_square.n.01', 'punctuation.n.02', 'pure_mathematics.n.01', 'push.n.01', 'quality.n.01', 'railway.n.01', 'range.n.04', 'range.n.05', 'ration.n.01', 'reaction_propulsion.n.01', 'real_property.n.01', 'rectangle.n.01', 'region.n.01', 'region.n.03', 'regular_polygon.n.01', 'relation.n.01', 'relationship.n.03', 'relative.n.01', 'religion.n.02', 'repair_shop.n.01', 'representation.n.01', 'representation.n.02', 
'representational_process.n.01', 'representative.n.01', 'reproductive_cell.n.01', 'reproductive_structure.n.01', 'reptile_family.n.01', 'reptile_genus.n.01', 'reserve.n.02', 'residential_district.n.01', 'residue.n.01', 'resin.n.01', 'resource.n.03', 'respiratory_tract.n.01', 'restoration.n.06', 'retreat.n.02', 'rider.n.03', 'rig.n.03', 'right.n.01', 'robotics.n.01', 'roman_deity.n.01', 'room.n.02', 'rosid_dicot_genus.n.01', 'rotating_mechanism.n.01', 'row.n.01', 'rubber.n.01', 'ruminant.n.01', 'saint.n.01', 'salmonid.n.01', 'salt.n.01', 'sample.n.03', 'sanitary_condition.n.01', 'satirist.n.01', 'saurischian.n.01', 'saying.n.01', 'scholar.n.01', 'school.n.04', 'science.n.01', 'scientific_theory.n.01', 'scorpaenid.n.01', 'script.n.03', 'section.n.03', 'section.n.04', 'section.n.08', 'sediment.n.01', 'self-defense.n.01', 'semipermeable_membrane.n.01', 'sense_organ.n.01', 'serviceman.n.01', 'set.n.13', 'setting.n.02', 'settlement.n.06', 'sewing.n.02', 'shaft.n.08', 'shape.n.02', 'sheath.n.02', 'sheet.n.06', 'shell.n.02', 'show.n.01', 'side.n.04', 'side.n.05', 'side.n.09', 'sign.n.01', 'sign.n.11', 'signal.n.01', 'silhouette.n.02', 'situation.n.01', 'skating.n.01', 'skilled_worker.n.01', 'sleeper.n.01', 'slope.n.01', 'small_indefinite_quantity.n.01', 'small_person.n.01', 'smith.n.10', 'soapsuds.n.01', 'social_group.n.01', 'software.n.01', 'sole.n.01', 'solid.n.01', 'solid.n.03', 'solution.n.01', 'somatic_cell.n.01', 'sound.n.04', 'spatial_property.n.01', 'species.n.01', 'specimen.n.01', 'speech.n.02', 'speech_act.n.01', 'sphere.n.01', 'spirit.n.01', 'spirit.n.04', 'spiritual_being.n.01', 'splash.n.01', 'spot.n.05', 'spot.n.12', 'spring.n.03', 'square.n.01', 'squeeze.n.01', 'stable_gear.n.01', 'star.n.03', 'state.n.02', 'state_of_matter.n.01', 'statement.n.01', 'steel.n.01', 'steward.n.03', 'store.n.02', 'story.n.02', 'stratum.n.01', 'structure.n.01', 'structure.n.03', 'structural_formula.n.01', 'structure.n.04', 'styrene.n.01', 'subject.n.01', 'subjugation.n.01', 
'substance.n.01', 'substance.n.07', 'substance.n.08', 'suburb.n.01', 'sum.n.01', 'superior_skill.n.01', 'support.n.03', 'supporting_structure.n.01', 'surface.n.02', 'suspension.n.01', 'sweetening.n.01', 'swine.n.01', 'symbol.n.01', 'symbol.n.02', 'synapsid.n.01', 'synthetic.n.01', 'synthetic_resin.n.01', 'system.n.01', 'system.n.06', 'system_of_measurement.n.01', 'taste.n.03', 'taxonomic_group.n.01', 'temperature_change.n.01', 'terminal.n.01', 'test.n.05', 'text.n.01', 'texture.n.01', 'theory.n.01', 'thing.n.04', 'thing.n.08', 'thing.n.12', 'thinker.n.02', 'thinking.n.01', 'thoroughfare.n.01', 'toecap.n.01', 'top.n.01', 'top.n.02', 'topping.n.01', 'tract.n.01', 'trade.n.02', 'traffic.n.01', 'transaction.n.01', 'transducer.n.01', 'transgression.n.01', 'transparent_substance.n.01', 'transportation.n.02', 'traveler.n.01', 'triangle.n.01', 'trouble.n.03', 'tube.n.01', 'type.n.04', 'underbrush.n.01', 'ungulate.n.01', 'unicameral_script.n.01', 'union_representative.n.01', 'unit.n.02', 'unit.n.03', 'unit.n.05', 'unit_of_measurement.n.01', 'universe.n.01', 'unreality.n.01', 'unsoundness.n.01', 'unwelcome_person.n.01', 'upper_class.n.01', 'upper_surface.n.01', 'user.n.01', 'utility.n.06', 'valuable.n.01', 'vapor.n.01', 'vascular_system.n.01', 'vault.n.03', 'vegetation.n.01', 'vehicular_traffic.n.01', 'veranda.n.01', 'vertical_surface.n.01', 'vicinity.n.01', 'village.n.02', 'vinyl_polymer.n.01', 'visual_communication.n.01', 'visual_percept.n.01', 'visual_perception.n.01', 'visual_property.n.01', 'visual_signal.n.01', 'vital_principle.n.01', 'vogue.n.01', 'volatile_storage.n.01', 'ware.n.01', 'waste.n.01', 'watercourse.n.03', 'way.n.06', 'wealth.n.03', 'weave.n.01', 'weightlift.n.01', 'whole.n.01', 'whole.n.02', 'window.n.08', 'woman.n.01', 'work.n.01', 'work.n.02', 'worker.n.01', 'workman.n.01', 'workplace.n.01', 'writing.n.02', 'writing.n.04', 'written_communication.n.01', 'written_symbol.n.01', 'wrongdoer.n.01', 'wrongdoing.n.02', 'yard.n.09', 'zone.n.01'})

PICKUPABLE_EXCLUDED_CATEGORY_HYPERNYMS module-attribute

# Synset name -> short human-readable label. Presumably the label explains why
# objects under that synset are excluded from pickupables (TODO confirm at use site).
PICKUPABLE_EXCLUDED_CATEGORY_HYPERNYMS: dict[str, str] = {
    "sculpture.n.01": "sculpture/art object",
    "model.n.04": "scale model/replica",
    "miniature.n.02": "miniature/replica",
    "sign.n.02": "sign/placard",
    "eolith.n.01": "primitive stone implement",
    "paleolith.n.01": "primitive stone implement",
}

PICKUPABLE_EXCLUDED_EXACT_SYNSETS module-attribute

# Exact synset name -> short human-readable label. Presumably these specific
# synsets (not their hyponyms) are excluded from pickupables (TODO confirm at use site).
PICKUPABLE_EXCLUDED_EXACT_SYNSETS: dict[str, str] = {
    "plaything.n.01": "generic plaything",
    "toy.n.02": "generic toy (non-plaything sense)",
    "popgun.n.01": "toy gun",
    "arrowhead.n.01": "primitive implement",
    "stone.n.02": "building material",
    "emblem.n.01": "visual symbol",
    "logo.n.01": "visual symbol",
}

RECEPTACLE_HYPERNYM_INCLUDE_WITH_EXCLUSIONS module-attribute

# Included receptacle hypernym -> set of synset names carved out of it.
# An empty set means the hypernym is included without exceptions.
RECEPTACLE_HYPERNYM_INCLUDE_WITH_EXCLUSIONS: dict[str, set[str]] = {
    "box.n.01": set(),
    "receptacle.n.01": {"beehive.n.04"},
    "pan.n.03": set(),
    "vessel.n.03": {
        "ladle.n.01",
        "bathtub.n.01",
        "boiler.n.01",
        "tank.n.02",
        "bedpan.n.01",
    },
    "dish.n.01": set(),
    "basket.n.01": set(),
    "glass.n.02": set(),
    "workbasket.n.01": set(),
}

RECEPTACLE_INCLUDE_SYNSETS module-attribute

# Synset names always accepted as receptacles, in addition to the hyponyms of
# RECEPTACLE_HYPERNYM_INCLUDE_WITH_EXCLUSIONS (presumably; TODO confirm how they are combined).
# The value is an immutable frozenset, so the annotation says frozenset[str], not set[str].
RECEPTACLE_INCLUDE_SYNSETS: frozenset[str] = frozenset(
    {
        "flatware.n.01",
        "glassware.n.01",
        "dinnerware.n.01",
        "service.n.09",
        "gold_plate.n.01",
        "silver_plate.n.01",
        "crockery.n.01",
        "place_mat.n.01",
        "coaster.n.03",
        "tray.n.01",
        "saucer.n.02",
        "platter.n.01",
        "jar.n.01",
        "canister.n.02",
        "tin.n.02",
        "case.n.05",
        "baking_dish.n.01",
        "mixing_bowl.n.01",
        "salad_bowl.n.01",
        "serving_dish.n.01",
        "caddy.n.02",
        "bin.n.01",
    }
)

VALID_PICKUPABLE_OBJA_UIDS_PATH module-attribute

# Precomputed list of pickupable Objaverse UIDs, one per line; read by
# get_valid_pickupable_obja_uids when the file exists.
VALID_PICKUPABLE_OBJA_UIDS_PATH: str = "/weka/prior/datasets/robomolmo/asset_utility_refs/valid_pickupable_obja_uids.txt"

canonical_lemma

canonical_lemma(synset_name: str) -> str

Return the first (most canonical) lemma for a WordNet synset name.

Source code in molmo_spaces/utils/synset_utils.py
def canonical_lemma(synset_name: str) -> str:
    """Return the first (most canonical) lemma for a WordNet synset name.

    The lemma is returned in human-readable form: underscores are replaced by
    spaces.
    """
    first_lemma = wn.synset(synset_name).lemma_names()[0]
    return first_lemma.replace("_", " ")

filter_synsets_to_remove_hyponyms

filter_synsets_to_remove_hyponyms(synsets: Sequence[str] | Sequence[Synset]) -> list[str]
Source code in molmo_spaces/utils/synset_utils.py
def filter_synsets_to_remove_hyponyms(synsets: Sequence[str] | Sequence[Synset]) -> list[str]:
    """Drop every synset that has a hypernym elsewhere in the input.

    Args:
        synsets: WordNet synset names or ``Synset`` objects. The type of the
            first element decides whether the sequence is treated as names.

    Returns:
        Names of the input synsets that are not descendants of another input
        synset. Duplicates are collapsed, and the order of first appearance in
        ``synsets`` is preserved so the result is deterministic across runs.
    """
    if len(synsets) == 0:
        return []

    hyper_to_descs = generate_hypernym_to_descendants(synsets=synsets)

    if isinstance(synsets[0], Synset):
        synsets = [s.name() for s in synsets]

    # Order-preserving dedup; iterating a plain set of strings would make the
    # output order depend on hash randomization.
    unique_names = list(dict.fromkeys(synsets))

    to_remove: set[str] = set()
    for name in unique_names:
        # All input synsets that have `name` on one of their hypernym paths;
        # this includes the synset itself, hence the `!=` check below.
        descs = hyper_to_descs[name]
        if len(descs) > 1:
            for desc in descs:
                if desc.name() != name:
                    to_remove.add(desc.name())

    return [name for name in unique_names if name not in to_remove]

generate_all_hypernyms_with_exclusions

generate_all_hypernyms_with_exclusions(synset: str | Synset, excluded: set[str] | str = EXCLUDED_HYPERNYMS, include_self_synset: bool = True) -> set[Synset]
Source code in molmo_spaces/utils/synset_utils.py
def generate_all_hypernyms_with_exclusions(
    synset: str | Synset,
    excluded: set[str] | str = EXCLUDED_HYPERNYMS,
    include_self_synset: bool = True,
) -> set[Synset]:
    """Collect every synset that appears on any hypernym path of ``synset``.

    Args:
        synset: WordNet synset or synset name; ``None`` yields an empty set.
        excluded: Synset names to leave out. If a ``str`` is passed, ``in``
            performs a substring test rather than a membership test.
        include_self_synset: Whether ``synset`` itself may appear in the result
            (it is still dropped if its name is in ``excluded``).

    Returns:
        Set of ``Synset`` objects.
    """
    if synset is None:
        return set()

    syn = wn.synset(synset) if isinstance(synset, str) else synset

    collected: set[Synset] = set()
    for path in syn.hypernym_paths():
        for hyper in path:
            if hyper == syn and not include_self_synset:
                continue
            if hyper.name() in excluded:
                continue
            collected.add(hyper)
    return collected

generate_hypernym_to_descendants

generate_hypernym_to_descendants(synsets: Sequence[str] | Sequence[Synset]) -> dict[str, list[Synset]]
Source code in molmo_spaces/utils/synset_utils.py
def generate_hypernym_to_descendants(
    synsets: Sequence[str] | Sequence[Synset],
) -> dict[str, list[Synset]]:
    """Index the input synsets by every hypernym on their hypernym paths.

    Args:
        synsets: WordNet synset names or ``Synset`` objects. The type of the
            first element decides whether the sequence is treated as names.

    Returns:
        ``{}`` for empty input, otherwise a ``defaultdict(list)`` mapping each
        hypernym name to the input synsets that have it on one of their
        hypernym paths. Duplicate inputs are collapsed, each synset is listed at
        most once per key, and both keys and lists follow first-appearance
        order, so the result is deterministic. ``hypernym_paths()`` includes the
        synset itself, so every input synset is also a key that lists itself.
    """
    if len(synsets) == 0:
        return {}

    if isinstance(synsets[0], str):
        synsets = [wn.synset(s) for s in synsets]

    # Order-preserving dedup; set iteration order would vary between runs
    # because Synset hashing is based on string hashes.
    unique_synsets = list(dict.fromkeys(synsets))

    hypernym_to_descendants = defaultdict(list)
    for s in unique_synsets:
        # Flatten all paths (linear, unlike sum(paths, [])) and dedupe so a
        # hypernym shared by several paths records this synset only once.
        hypernyms = dict.fromkeys(h for path in s.hypernym_paths() for h in path)
        for hypernym in hypernyms:
            hypernym_to_descendants[hypernym.name()].append(s)

    return hypernym_to_descendants

get_all_synsets_in_metadata

get_all_synsets_in_metadata() -> list[Synset]
Source code in molmo_spaces/utils/synset_utils.py
def get_all_synsets_in_metadata() -> list[Synset]:
    """Return every WordNet synset referenced by the asset metadata, sorted by name.

    Sources are the ``"synset"`` field of each object annotation (when present)
    and the values of ``AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET``; duplicates
    are removed.
    """
    names = {ann["synset"] for ann in ObjectMeta.annotation().values() if "synset" in ann}
    names.update(AI2THOR_OBJECT_TYPE_TO_WORDNET_SYNSET.values())
    unique_synsets = {wn.synset(name) for name in names}
    return sorted(unique_synsets, key=lambda syn: syn.name())

get_highest_relevant_hypernym

get_highest_relevant_hypernym(synset: str | Synset, excluded: set[str] | str = EXCLUDED_HYPERNYMS) -> Synset
Source code in molmo_spaces/utils/synset_utils.py
def get_highest_relevant_hypernym(
    synset: str | Synset,
    excluded: set[str] | str = EXCLUDED_HYPERNYMS,
) -> str:
    """Return the name of the most general hypernym of ``synset`` not in ``excluded``.

    Hypernym paths are scanned in the order given by ``hypernym_paths()``,
    each from its root end downward; the first name not in ``excluded`` is
    returned. Later paths are only consulted when every entry of the earlier
    ones is excluded.

    Args:
        synset: WordNet synset or synset name.
        excluded: Synset names to skip. If a ``str`` is passed, ``in``
            performs a substring test rather than a membership test.

    Returns:
        A synset *name* (``str``). Falls back to ``synset``'s own name when
        nothing on the scanned paths qualifies.
    """
    if isinstance(synset, str):
        synset = wn.synset(synset)

    for hpath in synset.hypernym_paths():
        for hyp in hpath:
            if hyp.name() not in excluded:
                return hyp.name()

    return synset.name()  # return self if no non-excluded hypernyms

get_hypernym_to_descendants_for_all_metadata_synsets

get_hypernym_to_descendants_for_all_metadata_synsets()
Source code in molmo_spaces/utils/synset_utils.py
def get_hypernym_to_descendants_for_all_metadata_synsets():
    """Build the hypernym-name -> descendant-synsets index over all metadata synsets.

    Thin wrapper: ``generate_hypernym_to_descendants(get_all_synsets_in_metadata())``.
    """
    return generate_hypernym_to_descendants(get_all_synsets_in_metadata())

get_hyponyms_of_synset cached

get_hyponyms_of_synset(synset: str | Synset, return_strings: bool) -> set[Synset] | set[str]
Source code in molmo_spaces/utils/synset_utils.py
@lru_cache(maxsize=10000, typed=True)
def _get_hyponyms_of_synset_frozen(
    synset: str | Synset, return_strings: bool
) -> frozenset[Synset] | frozenset[str]:
    """Memoized worker for get_hyponyms_of_synset; returns an immutable set."""
    if isinstance(synset, str):
        synset = wn.synset(synset)

    hyps = {synset.name() if return_strings else synset}
    for h in synset.hyponyms():
        # Called positionally so every recursive call shares one cache key form.
        hyps.update(_get_hyponyms_of_synset_frozen(h, return_strings))

    return frozenset(hyps)


def get_hyponyms_of_synset(synset: str | Synset, return_strings: bool) -> set[Synset] | set[str]:
    """Return ``synset`` together with all of its transitive hyponyms.

    Only ``hyponyms()`` links are followed.

    Args:
        synset: WordNet synset or synset name.
        return_strings: If True, return synset names; otherwise ``Synset`` objects.

    Returns:
        A fresh ``set`` on every call. Results are memoized internally as
        frozensets, so mutating the returned set cannot corrupt the cache.
        Previously the cached set itself was handed out.
    """
    return set(_get_hyponyms_of_synset_frozen(synset, return_strings))


# Keep the lru_cache management API reachable on the public name.
get_hyponyms_of_synset.cache_info = _get_hyponyms_of_synset_frozen.cache_info
get_hyponyms_of_synset.cache_clear = _get_hyponyms_of_synset_frozen.cache_clear
get_hyponyms_of_synset.cache_parameters = _get_hyponyms_of_synset_frozen.cache_parameters

get_hyponyms_of_synsets

get_hyponyms_of_synsets(synsets: Iterable[str] | Iterable[Synset], return_strings: bool) -> set[Synset] | set[str]
Source code in molmo_spaces/utils/synset_utils.py
def get_hyponyms_of_synsets(
    synsets: Iterable[str] | Iterable[Synset], return_strings: bool
) -> set[Synset] | set[str]:
    """Union of ``get_hyponyms_of_synset`` over every synset in ``synsets``.

    Args:
        synsets: WordNet synsets or synset names.
        return_strings: If True, collect synset names; otherwise ``Synset`` objects.

    Returns:
        A new set containing each input synset and all of its transitive hyponyms.
    """
    combined: set[Synset] | set[str] = set()
    for syn in synsets:
        combined.update(get_hyponyms_of_synset(syn, return_strings=return_strings))
    return combined

get_singleton_highest_hypernyms cached

get_singleton_highest_hypernyms()
Source code in molmo_spaces/utils/synset_utils.py
@cache
def get_singleton_highest_hypernyms():
    """Return highest relevant hypernyms that exactly one metadata synset maps to.

    Every synset from ``get_all_synsets_in_metadata()`` is mapped through
    ``get_highest_relevant_hypernym``. The names that occur only once are
    returned. The result is memoized with ``@cache``, so every call returns
    the same set object; callers should not mutate it.
    """
    counts = Counter(get_highest_relevant_hypernym(syn) for syn in get_all_synsets_in_metadata())
    return {hyp for hyp, n in counts.items() if n < 2}

get_valid_pickupable_obja_uids

get_valid_pickupable_obja_uids(debug: bool = False) -> list[str]

Get all Objaverse asset UIDs that are pickupable (have valid grasp files).

Checks for a cached file at VALID_PICKUPABLE_OBJA_UIDS_PATH first to avoid expensive computation. If the file is not found, computes and returns the list; the computed list is not written back to the cache file.

Parameters:

Name Type Description Default
debug bool

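If True and the list is computed, prints up to 20 random samples with their short descriptions; if the list is loaded from the cache file, prints only the number of loaded UIDs.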

False

Returns:

Type Description
list[str]

List of UIDs for valid pickupable assets.

Source code in molmo_spaces/utils/synset_utils.py
def get_valid_pickupable_obja_uids(debug: bool = False) -> list[str]:
    """
    Return the UIDs of every Objaverse asset that can be picked up.

    An asset counts as pickupable when ``has_valid_pickup_grasps(uid)`` is true.
    If a precomputed list exists at VALID_PICKUPABLE_OBJA_UIDS_PATH (one UID per
    line, blank lines ignored), it is returned directly. Otherwise every
    annotated asset is checked, which is expensive. The computed list is not
    written back to that path.

    Args:
        debug: If True, print diagnostics. On the cache path only the number of
            loaded UIDs is printed. On the compute path up to 20 randomly chosen
            UIDs are printed with their first short description (falling back
            to the category).

    Returns:
        List of UIDs for valid pickupable assets.
    """
    import os

    cache_path = VALID_PICKUPABLE_OBJA_UIDS_PATH
    if os.path.exists(cache_path):
        with open(cache_path) as fh:
            stripped_lines = (raw.strip() for raw in fh)
            cached_uids = [uid for uid in stripped_lines if uid]
        if debug:
            print(f"\n=== Loaded {len(cached_uids)} pickupable UIDs from cache ===\n")
        return cached_uids

    # These imports are only needed on the slow (compute) path.
    from molmo_spaces.utils.grasps import has_valid_pickup_grasps
    from molmo_spaces.utils.object_metadata import ObjectMeta

    valid_uids = {
        uid: anno
        for uid, anno in ObjectMeta.annotation().items()
        if has_valid_pickup_grasps(uid)
    }

    if debug:
        import random

        sample_size = min(20, len(valid_uids))
        sampled = random.sample(list(valid_uids), sample_size)

        print(f"\n=== Pickable Objaverse Assets ({len(valid_uids)} total) ===")
        print(f"Random sample of {sample_size}:\n")
        for uid in sampled:
            anno = valid_uids[uid]
            descriptions = list(anno.get("description_short", {}).values())
            label = descriptions[0] if descriptions else anno.get("category", "N/A")
            print(f"  {uid}: {label}")
        print()

    return list(valid_uids)

get_valid_pickupable_obja_uids_excluding_benchmark

get_valid_pickupable_obja_uids_excluding_benchmark(debug: bool = False) -> list[str]

Get pickupable objaverse UIDs with benchmark assets excluded.

Loads the benchmark blacklist from BENCHMARK_BLACKLIST_UIDS_PATH (generated by scripts/roseh/extract_benchmark_assets.py) and removes any UIDs that appear in any bench_v3 benchmark as a pickup or placement asset.

Parameters:

Name Type Description Default
debug bool

If True, prints how many assets were filtered and samples of what was removed.

False

Returns:

Type Description
list[str]

List of UIDs for valid pickupable assets not used in any benchmark.

Source code in molmo_spaces/utils/synset_utils.py
def get_valid_pickupable_obja_uids_excluding_benchmark(debug: bool = False) -> list[str]:
    """Get pickupable objaverse UIDs with benchmark assets excluded.

    Loads the benchmark blacklist from BENCHMARK_BLACKLIST_UIDS_PATH
    (generated by scripts/roseh/extract_benchmark_assets.py) and removes
    any UIDs that appear in any bench_v3 benchmark as a pickup or placement
    asset.

    Args:
        debug: If True, prints how many assets were filtered and samples
            of what was removed (the first up-to-10 removed UIDs, in the
            order they appear in the pickupable list).

    Returns:
        List of UIDs for valid pickupable assets not used in any benchmark,
        in the same order as ``get_valid_pickupable_obja_uids()``.
    """
    max_removed_samples = 10

    all_uids = get_valid_pickupable_obja_uids(debug=False)
    blacklist = set(_load_uid_list(BENCHMARK_BLACKLIST_UIDS_PATH))

    original_count = len(all_uids)
    filtered: list[str] = []
    removed: list[str] = []
    for uid in all_uids:
        (removed if uid in blacklist else filtered).append(uid)
    removed_count = len(removed)

    if debug:
        print(
            f"\n=== Benchmark blacklist filtering ===\n"
            f"  {original_count} pickupable UIDs total\n"
            f"  {len(blacklist)} UIDs in benchmark blacklist\n"
            f"  {removed_count} removed (appeared in both)\n"
            f"  {len(filtered)} remaining after filtering\n"
        )
        # The docstring promises samples of removed UIDs; print them.
        if removed:
            shown = removed[:max_removed_samples]
            print(f"  Sample of removed UIDs ({len(shown)} of {removed_count}):")
            for uid in shown:
                print(f"    {uid}")

    return filtered

get_valid_receptacle_uids

get_valid_receptacle_uids() -> dict[str, dict]

Get all asset UIDs that are valid receptacles based on synset filtering.

Returns:

Type Description
dict[str, dict]

Dict mapping UID to annotation dict for valid receptacle assets.

Source code in molmo_spaces/utils/synset_utils.py
def get_valid_receptacle_uids() -> dict[str, dict]:
    """
    Return annotated assets whose synset is a valid receptacle type.

    An asset qualifies when its annotation has a truthy ``"synset"`` accepted by
    ``is_valid_receptacle_synset`` and a truthy ``"receptacle"`` flag.

    Returns:
        Dict mapping UID to annotation dict for valid receptacle assets.
    """
    from molmo_spaces.utils.object_metadata import ObjectMeta

    return {
        uid: anno
        for uid, anno in ObjectMeta.annotation().items()
        # Synset check first, then the explicit receptacle flag.
        if (synset := anno.get("synset"))
        and is_valid_receptacle_synset(synset)
        and anno.get("receptacle", False)
    }

is_hypernym_of cached

is_hypernym_of(synset: str | Synset, possible_hypernym: str | Synset) -> bool
Source code in molmo_spaces/utils/synset_utils.py
@lru_cache(maxsize=10000, typed=True)
def is_hypernym_of(synset: str | Synset, possible_hypernym: str | Synset) -> bool:
    """Return True if ``possible_hypernym`` is ``synset`` or one of its ancestors.

    Both arguments may be ``Synset`` objects or synset names; results are
    memoized with ``lru_cache``.
    """
    child = wn.synset(synset) if isinstance(synset, str) else synset
    parent = (
        wn.synset(possible_hypernym)
        if isinstance(possible_hypernym, str)
        else possible_hypernym
    )
    # An ancestor of `child` is the deepest hypernym it shares with `child`,
    # so it appears among their lowest common hypernyms.
    return parent in child.lowest_common_hypernyms(parent)

is_subsynset_of

is_subsynset_of(synset: str | Synset, other_synset: str | Synset) -> bool
Source code in molmo_spaces/utils/synset_utils.py
def is_subsynset_of(synset: str | Synset, other_synset: str | Synset) -> bool:
    """Return True if ``synset`` falls under ``other_synset`` in the WordNet hierarchy.

    Alias for ``is_hypernym_of`` with ``other_synset`` as the candidate hypernym.
    """
    falls_under = is_hypernym_of(synset=synset, possible_hypernym=other_synset)
    return falls_under

is_valid_receptacle_synset

is_valid_receptacle_synset(synset: str | Synset) -> bool

Check if a synset is a valid receptacle based on inclusion/exclusion rules.

The cached valid set already contains all hyponyms of included hypernyms, so a simple set membership check is sufficient.

Parameters:

Name Type Description Default
synset str | Synset

A WordNet synset or synset name string

required

Returns:

Type Description
bool

True if the synset is a valid receptacle type

Source code in molmo_spaces/utils/synset_utils.py
def is_valid_receptacle_synset(synset: str | Synset) -> bool:
    """
    Check if a synset is a valid receptacle based on inclusion/exclusion rules.

    The precomputed valid set (``_get_all_valid_receptacle_synsets()``) already
    contains all hyponyms of included hypernyms, so membership is enough.

    Args:
        synset: A WordNet synset or synset name string; ``None`` is rejected.

    Returns:
        True if the synset is a valid receptacle type
    """
    if synset is None:
        return False

    name = synset.name() if isinstance(synset, Synset) else synset
    return name in _get_all_valid_receptacle_synsets()

symmetric_subsynset_of

symmetric_subsynset_of(synset: str | Synset, other_synset: str | Synset) -> bool
Source code in molmo_spaces/utils/synset_utils.py
def symmetric_subsynset_of(synset: str | Synset, other_synset: str | Synset) -> bool:
    """Return True if either synset is a hypernym of (or the same as) the other."""
    if is_hypernym_of(synset=synset, possible_hypernym=other_synset):
        return True
    return is_hypernym_of(synset=other_synset, possible_hypernym=synset)

task_relevant_objects_and_workspace_utils

Derive task-relevant object names and workspace center from task config fields.

Single source of truth for which objects cameras must see and what defines the workspace center. Called from:

- Task samplers (resolve_visibility_object, get_workspace_center) during data generation
- create_json_benchmark.py, to populate EpisodeSpec.task_relevant_objects
- The eval camera system, for visibility checks and workspace center computation

Accepts either a pydantic config object or a plain dict.

Functions:

Name Description
compute_workspace_center

Compute the workspace center as the centroid of named 3-D positions.

compute_workspace_center_from_object_poses

Compute workspace center from serialized object poses (e.g. JSON episode data).

get_task_relevant_objects

Return the list of object body names that are relevant for this task.

compute_workspace_center

compute_workspace_center(positions: dict[str, ndarray]) -> ndarray

Compute the workspace center as the centroid of named 3-D positions.

This is the shared implementation used by both live task samplers (positions from the environment) and the eval camera system (positions from JSON episode data).

Parameters:

Name Type Description Default
positions dict[str, ndarray]

Mapping of label -> 3-D position array. Typical keys are object body names from get_task_relevant_objects() plus "gripper" for the end-effector. Must contain at least one entry.

required

Returns:

Type Description
ndarray

3-D centroid (mean) of all positions.

Source code in molmo_spaces/utils/task_relevant_objects_and_workspace_utils.py
def compute_workspace_center(positions: dict[str, np.ndarray]) -> np.ndarray:
    """Compute the workspace center as the centroid of named 3-D positions.

    Shared by the live task samplers (positions read from the environment) and
    the eval camera system (positions read from JSON episode data).

    Args:
        positions: Mapping of label -> 3-D position array.  Typical keys are
            object body names from ``get_task_relevant_objects`` plus
            ``"gripper"`` for the end-effector.  Must contain at least one entry.

    Returns:
        3-D centroid (mean) of all positions.

    Raises:
        ValueError: If ``positions`` is empty.
    """
    if not positions:
        raise ValueError("positions dict must contain at least one entry")
    return np.mean(list(positions.values()), axis=0)

compute_workspace_center_from_object_poses

compute_workspace_center_from_object_poses(object_names: list[str], object_poses: dict[str, list[float]], gripper_pos: ndarray | None = None) -> ndarray

Compute workspace center from serialized object poses (e.g. JSON episode data).

Convenience wrapper around compute_workspace_center() for the eval path, where positions come from EpisodeSpec.scene_modifications.object_poses (each value is [x, y, z, qw, qx, qy, qz]).

Parameters:

Name Type Description Default
object_names list[str]

Body names whose positions should contribute (typically from get_task_relevant_objects() or EpisodeSpec.task_relevant_objects).

required
object_poses dict[str, list[float]]

Mapping of body name to 7-D pose [x, y, z, qw, qx, qy, qz].

required
gripper_pos ndarray | None

Optional gripper position to include.

None

Returns:

Type Description
ndarray

3-D centroid.

Source code in molmo_spaces/utils/task_relevant_objects_and_workspace_utils.py
def compute_workspace_center_from_object_poses(
    object_names: list[str],
    object_poses: dict[str, list[float]],
    gripper_pos: np.ndarray | None = None,
) -> np.ndarray:
    """Compute workspace center from serialized object poses (e.g. JSON episode data).

    Eval-path convenience wrapper around ``compute_workspace_center``, where
    positions come from ``EpisodeSpec.scene_modifications.object_poses``
    (each value is ``[x, y, z, qw, qx, qy, qz]``). Names missing from
    ``object_poses`` (or mapped to ``None``) are skipped.

    Args:
        object_names: Body names whose positions should contribute (typically
            from ``get_task_relevant_objects`` or ``EpisodeSpec.task_relevant_objects``).
        object_poses: Mapping of body name to 7-D pose ``[x, y, z, qw, qx, qy, qz]``.
        gripper_pos: Optional gripper position to include under key ``"gripper"``.

    Returns:
        3-D centroid.

    Raises:
        ValueError: If neither an object position nor ``gripper_pos`` is available.
    """
    # Only the translation part (first 3 entries) of each pose is used.
    positions: dict[str, np.ndarray] = {
        name: np.asarray(pose[:3], dtype=float)
        for name in object_names
        if (pose := object_poses.get(name)) is not None
    }

    if gripper_pos is not None:
        positions["gripper"] = np.asarray(gripper_pos, dtype=float)

    if not positions:
        raise ValueError(
            f"No positions found for any of {object_names} in object_poses "
            f"(available: {list(object_poses.keys())})"
        )

    return compute_workspace_center(positions)

get_task_relevant_objects

get_task_relevant_objects(task_config: Any) -> list[str]

Return the list of object body names that are relevant for this task.

These are the objects that cameras should be able to see (visibility constraints) and whose positions define the workspace center.

Parameters:

Name Type Description Default
task_config Any

A task config object (e.g. PickTaskConfig) or a dict with the same keys (as stored in EpisodeSpec.task).

required

Returns:

Type Description
list[str]

Deduplicated list of object body names, in stable insertion order.

Source code in molmo_spaces/utils/task_relevant_objects_and_workspace_utils.py
def get_task_relevant_objects(task_config: Any) -> list[str]:
    """Return the list of object body names that are relevant for this task.

    These are the objects that cameras should be able to see (visibility
    constraints) and whose positions define the workspace center. Names are
    read, in order, from ``pickup_obj_name``, ``place_receptacle_name`` and
    ``other_receptacle_names``. Falsy values (e.g. ``None`` or ``""``) are skipped.

    Args:
        task_config: A task config object (e.g. PickTaskConfig) or a dict
            with the same keys (as stored in EpisodeSpec.task).

    Returns:
        Deduplicated list of object body names, in stable insertion order.
    """
    candidates = [
        # Pickup / target object (pick, pnp, open/close, nav)
        _get(task_config, "pickup_obj_name"),
        # Place receptacle (pnp, next-to, color)
        _get(task_config, "place_receptacle_name"),
        # Distractor receptacles (pnp color)
        *(_get(task_config, "other_receptacle_names") or []),
    ]
    # dict.fromkeys keeps only the first occurrence of each name, in order.
    return list(dict.fromkeys(name for name in candidates if name))

test_utils

Shared utilities for data generation tests (Franka, RUM, etc.).

Functions:

Name Description
assert_obs_scene_match

Assert that the obs_scenes of two trajectory groups are equal.

assert_observations_match

Compare actual and expected observations using Structural Similarity Index (SSIM).

assert_python_types_equal

General (recursive) function to assert that two python objects are equal, with tolerance applied for floats.

compare_h5_groups

Recursively compare two HDF5 groups and check for differences.

print_profiling_summary

Print a formatted summary of profiling results.

run_policy_for_steps

Run a policy on a task for a fixed number of steps, following pipeline.py API.

run_task_for_steps_with_observations

Run a policy on a task for a fixed number of steps and return both qpos and observations.

save_observation_comparison

Save visual observation comparisons including actual, expected, and difference images.

save_visual_observations

Save visual observations as viewable PNG images for debugging.

verify_and_compare_camera_observations

Verify observation structure and compare camera observations against saved test data.

verify_and_compare_camera_observations_after_steps

Verify and compare camera observations after running policy steps against saved test data.

verify_video_fps

Assert all videos in a directory have the expected FPS.

Attributes:

Name Type Description
log

log module-attribute

# Module-level logger for the test utilities, named after this module.
log = getLogger(__name__)

assert_obs_scene_match

assert_obs_scene_match(g1: Group, g2: Group, atol=0.001)

Assert that the obs_scenes of two trajectory groups are equal.

Parameters:

Name Type Description Default
g1 Group

h5py.Group of the first trajectory

required
g2 Group

h5py.Group of the second trajectory

required
atol

Absolute tolerance used when comparing float values in obs_scene.

0.001
Source code in molmo_spaces/utils/test_utils.py
def assert_obs_scene_match(g1: h5py.Group, g2: h5py.Group, atol=0.001):
    """
    Assert that the obs_scenes of two trajectory groups are equal.

    Each group's ``obs_scene`` dataset is decoded as UTF-8 JSON (trailing NUL
    padding stripped). Key sets must match exactly. Values are compared with
    ``assert_python_types_equal``, and the ``frozen_config`` entry is skipped.

    Args:
        g1: h5py.Group of the first trajectory
        g2: h5py.Group of the second trajectory
        atol: Absolute tolerance for float values.

    Raises:
        AssertionError: On a key-set or value mismatch; the message names both files.
    """

    def _load_scene(group):
        return json.loads(group["obs_scene"][()].decode("utf-8").rstrip("\x00"))

    obs_scene_1 = _load_scene(g1)
    obs_scene_2 = _load_scene(g2)

    keys_1 = set(obs_scene_1.keys())
    keys_2 = set(obs_scene_2.keys())
    if keys_1 != keys_2:
        f2_missing = keys_1 - keys_2
        f1_missing = keys_2 - keys_1
        raise AssertionError(
            f"obs_scene keys mismatch: {g1.file.filename}:{g1.name}/obs_scene missing {f1_missing} "
            f"and {g2.file.filename}:{g2.name}/obs_scene missing {f2_missing}"
        )

    try:
        for key, value_1 in obs_scene_1.items():
            # frozen_config is a pickled object, so don't compare it
            if key == "frozen_config":
                continue
            assert_python_types_equal(f"obs_scene[{key}]", value_1, obs_scene_2[key], atol=atol)
    except AssertionError as err:
        raise AssertionError(
            f"obs_scene mismatch: {g1.file.filename}:{g1.name} vs {g2.file.filename}:{g2.name}: {err}"
        ) from err

assert_observations_match

assert_observations_match(actual_obs, expected_obs, sensor_name, atol=0, rtol=1e-07, ssim_threshold=0.9)

Compare actual and expected observations using Structural Similarity Index (SSIM).

Uses SSIM to compare images perceptually rather than pixel-by-pixel, which is more robust to minor rendering variations while still catching meaningful visual differences.

Parameters:

Name Type Description Default
actual_obs

Actual observation array

required
expected_obs

Expected observation array

required
sensor_name

Name of the sensor (for error messages)

required
atol

Unused, kept for API compatibility

0
rtol

Unused, kept for API compatibility

1e-07
ssim_threshold

Minimum SSIM score for the observations to be considered a match.

0.9

Raises:

Type Description
AssertionError

If observations have meaningful visual differences (low SSIM)

Source code in molmo_spaces/utils/test_utils.py
def assert_observations_match(
    actual_obs, expected_obs, sensor_name, atol=0, rtol=1e-7, ssim_threshold=0.9
):
    """
    Compare actual and expected observations using Structural Similarity Index (SSIM).

    Uses SSIM to compare images perceptually rather than pixel-by-pixel, which is more
    robust to minor rendering variations while still catching meaningful visual differences.

    Args:
        actual_obs: Actual observation array, (H, W, C) color or (H, W) single-channel
        expected_obs: Expected observation array, same shape as ``actual_obs``
        sensor_name: Name of the sensor (for error messages)
        atol: Unused, kept for API compatibility
        rtol: Unused, kept for API compatibility
        ssim_threshold: Minimum SSIM score for the observations to be considered a match

    Raises:
        AssertionError: If the observations differ in shape, or have meaningful
            visual differences (SSIM below ``ssim_threshold``)
    """
    actual_obs = np.asarray(actual_obs)
    expected_obs = np.asarray(expected_obs)

    # Fail with a clear message up front; otherwise ssim() raises an opaque
    # error and the pixel diff below would broadcast or fail.
    if actual_obs.shape != expected_obs.shape:
        raise AssertionError(
            f"Sensor {sensor_name} observation shape mismatch: "
            f"{actual_obs.shape} vs {expected_obs.shape}"
        )

    # SSIM returns a value between -1 and 1, where 1 is perfect similarity.
    # For color images channel_axis=2 computes SSIM per channel and averages;
    # a 2-D (H, W) image has no channel axis.
    channel_axis = None if actual_obs.ndim == 2 else 2
    ssim_score = ssim(
        actual_obs,
        expected_obs,
        channel_axis=channel_axis,
        data_range=255,  # assumes uint8 images
    )

    # Also compute basic pixel difference stats for debugging
    diff = np.abs(actual_obs.astype(np.int32) - expected_obs.astype(np.int32))
    diff_max = np.max(diff)
    diff_mean = np.mean(diff)
    num_different_pixels = np.sum(diff > 0)
    total_pixels = diff.size
    percent_different = 100 * num_different_pixels / total_pixels

    # SSIM > 0.90 is considered very similar; > 0.95 nearly identical. The
    # default stays at 0.9 to tolerate GPU rendering variance.
    MIN_SSIM_THRESHOLD = ssim_threshold

    if ssim_score < MIN_SSIM_THRESHOLD:
        # Print detailed statistics
        diff_sum = np.sum(diff)
        print(f"\n[DIFF] Sensor {sensor_name}:")
        print(f"  SSIM score: {ssim_score:.4f} (threshold: {MIN_SSIM_THRESHOLD})")
        print(f"  Total difference sum: {diff_sum}")
        print(f"  Max pixel difference: {diff_max}")
        print(f"  Mean pixel difference: {diff_mean:.4f}")
        print(
            f"  Different pixels: {num_different_pixels}/{total_pixels} ({percent_different:.2f}%)"
        )

        raise AssertionError(
            f"Sensor {sensor_name} observations have meaningful structural differences from saved test data. "
            f"SSIM score: {ssim_score:.4f} (threshold: {MIN_SSIM_THRESHOLD})"
        )

assert_python_types_equal

assert_python_types_equal(pfx: str, v1, v2, atol=0.001)

General (recursive) function to assert that two python objects are equal, with tolerance applied for floats. Works for native python primitives only.

Source code in molmo_spaces/utils/test_utils.py
def assert_python_types_equal(pfx: str, v1, v2, atol=0.001):
    """
    General (recursive) function to assert that two python objects are equal, with tolerance applied for floats.
    Works for native python primitives only.
    """
    if type(v1) is not type(v2):
        raise AssertionError(f"{pfx} type mismatch: {type(v1)} vs {type(v2)}")

    if isinstance(v1, str | bytes | bool | int):
        assert v1 == v2, f"{pfx} value mismatch: {v1} vs {v2}"
    elif isinstance(v1, float):
        assert abs(v1 - v2) < atol, f"{pfx} value mismatch: {v1} vs {v2}"
    elif isinstance(v1, list | tuple):
        assert len(v1) == len(v2), f"{pfx} length mismatch: {len(v1)} vs {len(v2)}"
        for i in range(len(v1)):
            assert_python_types_equal(f"{pfx}[{i}]", v1[i], v2[i], atol=atol)
    elif isinstance(v1, dict):
        assert set(v1.keys()) == set(v2.keys()), (
            f"{pfx} keys mismatch: {set(v1.keys())} vs {set(v2.keys())}"
        )
        for k in v1:
            assert_python_types_equal(f"{pfx}[{k}]", v1[k], v2[k], atol=atol)
    else:
        raise AssertionError(f"{pfx} unknown type: {type(v1)}")

compare_h5_groups

compare_h5_groups(g1, g2, path='/', atol=1e-06, ignore_paths=None)

Recursively compare two HDF5 groups and check for differences.

Parameters:

Name Type Description Default
g1

First HDF5 group

required
g2

Second HDF5 group

required
path

Current path in the HDF5 structure (for error messages)

'/'
atol

Absolute tolerance for numerical comparisons

1e-06
ignore_paths

Optional set/list of path suffixes to skip (e.g., {"object_image_points"})

None
Source code in molmo_spaces/utils/test_utils.py
def _is_numeric_dtype(dtype):
    """Return True if values of ``dtype`` can go through ``np.allclose`` (numbers and booleans)."""
    return np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)


def _max_abs_diff(a, b):
    """Largest absolute element-wise difference, computed in a widened dtype so unsigned ints can't wrap."""
    wide = np.result_type(np.asarray(a).dtype, np.float64)
    return np.abs(np.asarray(a, dtype=wide) - np.asarray(b, dtype=wide)).max()


def _compare_h5_attrs(obj1, obj2, item_path, atol):
    """Check that every attribute of ``obj1`` exists on ``obj2`` with an equal value."""
    for attr in obj1.attrs:
        assert attr in obj2.attrs, f"Missing attribute '{attr}' in {item_path} (second file)"
        a1 = obj1.attrs[attr]
        a2 = obj2.attrs[attr]
        if isinstance(a1, bytes):
            a1 = a1.decode(errors="ignore")
        if isinstance(a2, bytes):
            a2 = a2.decode(errors="ignore")
        if isinstance(a1, np.ndarray) and isinstance(a2, np.ndarray):
            if _is_numeric_dtype(a1.dtype) and _is_numeric_dtype(a2.dtype):
                # Compare shapes explicitly: np.allclose broadcasts, so [1] would "match" [1, 1, 1].
                same = a1.shape == a2.shape and np.allclose(a1, a2, equal_nan=True, atol=atol)
            else:
                # String/object arrays make np.allclose raise TypeError; require exact equality.
                same = np.array_equal(a1, a2)
            assert same, f"Attribute mismatch at {item_path}/{attr}"
        else:
            assert a1 == a2, f"Attribute mismatch at {item_path}/{attr}: {a1!r} vs {a2!r}"


def _compare_h5_datasets(obj1, obj2, item_path, atol):
    """Compare the values and attributes of two HDF5 datasets located at ``item_path``."""
    d1 = obj1[()]
    d2 = obj2[()]

    assert type(d1) is type(d2), f"Type mismatch at {item_path}: {type(d1)} vs {type(d2)}"

    if isinstance(d1, np.ndarray) and isinstance(d2, np.ndarray):
        assert d1.shape == d2.shape, f"Shape mismatch at {item_path}: {d1.shape} vs {d2.shape}"
        assert d1.dtype == d2.dtype, f"Type mismatch at {item_path}: {d1.dtype} vs {d2.dtype}"

        # Values are not compared for uint8 arrays (byte-encoded dicts get arbitrarily complicated)
        # nor for string/object/compound arrays, which np.allclose cannot handle.
        if d1.dtype != np.uint8 and _is_numeric_dtype(d1.dtype):
            if not np.allclose(d1, d2, atol=atol, equal_nan=True):
                if np.issubdtype(d1.dtype, np.bool_):
                    n_diff = np.sum(d1 != d2)
                    msg = f"Boolean mismatch at {item_path}, {n_diff}/{d1.size} elems differ"
                else:
                    msg = f"Data mismatch at {item_path}, w/ max err {_max_abs_diff(d1, d2)}"
                log.warning(f"{d1}")
                log.warning(f"{d2}")
                log.warning(msg)
                raise AssertionError(msg)

    elif isinstance(d1, int | float | np.number | np.bool_) and isinstance(
        d2, int | float | np.number | np.bool_
    ):
        # Scalar datasets come back from h5py as numpy scalars (np.int64, np.float32, ...), which are
        # not Python int/float subclasses, so they are matched explicitly here.
        if not np.allclose(d1, d2, atol=atol, equal_nan=True):
            msg = f"Data mismatch at {item_path}, w/ max err {_max_abs_diff(d1, d2)}"
            log.warning(msg)
            log.warning(f"{d1}")
            log.warning(f"{d2}")
            raise AssertionError(msg)

    _compare_h5_attrs(obj1, obj2, item_path, atol)


def compare_h5_groups(g1, g2, path="/", atol=1e-6, ignore_paths=None):
    """Recursively compare two HDF5 groups and check for differences.

    Child names must match in both directions and child objects must have the same h5py type;
    sub-groups are compared recursively. For datasets, the loaded Python types, array shapes and
    dtypes must match, numeric/boolean values must agree within ``atol`` (NaNs compare equal), and
    every attribute on the first file's dataset must exist on, and equal, the second's. uint8 arrays
    (byte-encoded payloads) and non-numeric arrays (strings, objects, compound types) are not
    value-compared. Group attributes, and attributes present only in the second file, are not checked.

    Args:
        g1: First HDF5 group
        g2: Second HDF5 group
        path: Current path in the HDF5 structure (for error messages)
        atol: Absolute tolerance for numerical comparisons
        ignore_paths: Optional set/list of path suffixes to skip (e.g., {"object_image_points"})

    Raises:
        AssertionError: At the first difference found.
    """
    ignore_paths = ignore_paths or set()

    def is_ignored(item_path):
        return any(item_path.endswith(ignore_suffix) for ignore_suffix in ignore_paths)

    for name in g1:
        item_path = path + name
        if is_ignored(item_path):
            continue

        assert name in g2, f"Missing in second file: {item_path}"

        obj1 = g1[name]
        obj2 = g2[name]

        assert type(obj1) is type(obj2), (
            f"Type mismatch at {item_path}: {type(obj1)} vs {type(obj2)}"
        )

        if isinstance(obj1, h5py.Group) and isinstance(obj2, h5py.Group):
            compare_h5_groups(
                obj1, obj2, path=item_path + "/", atol=atol, ignore_paths=ignore_paths
            )
        elif isinstance(obj1, h5py.Dataset) and isinstance(obj2, h5py.Dataset):
            _compare_h5_datasets(obj1, obj2, item_path, atol)

    # Check for extra keys in g2. Ignored paths are skipped here too, so an ignored item that only
    # exists in the second file is not reported as a difference.
    for name in g2:
        item_path = path + name
        if is_ignored(item_path):
            continue
        assert name in g1, f"Missing in first file: {item_path}"

print_profiling_summary

print_profiling_summary(profiler)

Build a formatted summary of profiling results and return it as a string (despite its name, the function itself does not print).

Parameters:

Name Type Description Default
profiler

Profiler instance with collected timing data

required

Returns:

Name Type Description
str

Formatted summary string

Source code in molmo_spaces/utils/test_utils.py
def print_profiling_summary(profiler):
    """
    Format the timings collected by ``profiler`` into a human-readable table.

    Despite its name, this function prints nothing: it returns the table so the caller can print
    or log it. Rows are ordered by total time (average time x number of calls), largest first.

    Args:
        profiler: Profiler instance exposing ``_avg_time`` (keyed by operation name),
            ``get_avg_time(name)`` and ``get_n(name)``; may be None.

    Returns:
        str: The formatted summary, or a short notice when there is no profiler or no data.
    """
    if profiler is None:
        return "No profiler available"

    op_names = list(profiler._avg_time.keys())
    if not op_names:
        return "No profiling data collected"

    width = 80
    heavy_rule = "=" * width
    light_rule = "-" * width

    def total_seconds(name):
        return profiler.get_avg_time(name) * profiler.get_n(name)

    rows = []
    for name in sorted(op_names, key=total_seconds, reverse=True):
        mean = profiler.get_avg_time(name)
        calls = profiler.get_n(name)
        rows.append(f"{name:<40} {calls:>8} {mean:>11.4f}s {mean * calls:>11.4f}s")

    header = [
        "\n" + heavy_rule,
        "PROFILING SUMMARY".center(width),
        heavy_rule,
        f"{'Operation':<40} {'Calls':>8} {'Avg Time':>12} {'Total Time':>12}",
        light_rule,
    ]
    return "\n".join([*header, *rows, heavy_rule])

run_policy_for_steps

run_policy_for_steps(task, policy, num_steps=10, profiler=None)

Run a policy on a task for up to a fixed number of steps (stopping early once the task reports it is done), following the pipeline.py API.

Parameters:

Name Type Description Default
task

The task instance

required
policy

The policy instance

required
num_steps

Number of steps to run

10
profiler

Optional profiler instance to track timing

None

Returns:

Name Type Description
tuple

(initial_qpos, final_qpos) as numpy arrays

Source code in molmo_spaces/utils/test_utils.py
def run_policy_for_steps(task, policy, num_steps=10, profiler=None):
    """
    Drive ``policy`` on ``task`` for at most ``num_steps`` steps, mirroring the pipeline.py loop.

    The policy is registered with the task and the task is reset. Each step then asks the policy
    for an action and applies it with ``task.step``; the loop exits early once ``task.is_done()``
    is true. The first robot's joint positions are sampled right after the reset and again after
    the loop, concatenated across move groups in ``robot_view.move_group_ids()`` order.

    Args:
        task: The task instance
        policy: The policy instance (must expose ``get_action(observation)``)
        num_steps: Maximum number of steps to run
        profiler: Optional profiler; when given, the whole run and every reset/get_action/step
            call are bracketed with ``start``/``end``.

    Returns:
        tuple: (initial_qpos, final_qpos) as numpy arrays
    """
    if profiler is None:
        def tic(_section):
            pass

        toc = tic
    else:
        tic, toc = profiler.start, profiler.end

    tic("test_policy_execution")

    task.register_policy(policy)

    tic("test_task_reset")
    observation, _info = task.reset()
    toc("test_task_reset")

    # Sample joint positions only after reset: that is the state the policy starts from.
    robot_view = task.env.robots[0].robot_view
    group_ids = robot_view.move_group_ids()

    def joint_positions():
        per_group = robot_view.get_qpos_dict(group_ids)
        return np.concatenate([per_group[gid] for gid in group_ids])

    initial_qpos = joint_positions()

    for _ in range(num_steps):
        tic("test_policy_get_action")
        action_cmd = policy.get_action(observation)
        toc("test_policy_get_action")

        tic("test_task_step")
        observation, _reward, _terminal, _truncated, _infos = task.step(action_cmd)
        toc("test_task_step")

        if task.is_done():
            break

    final_qpos = joint_positions()

    toc("test_policy_execution")
    return initial_qpos, final_qpos

run_task_for_steps_with_observations

run_task_for_steps_with_observations(task, policy, num_steps=10, profiler=None)

Run a policy on a task for up to a fixed number of steps (stopping early once the task reports it is done) and return both qpos and observations.

This extends run_policy_for_steps by also returning the initial observation (right after reset) and the final observation (after the steps have run). Useful for testing that observations are deterministic across runs and that they change appropriately.

Parameters:

Name Type Description Default
task

The task instance

required
policy

The policy instance

required
num_steps

Number of steps to run

10
profiler

Optional profiler instance to track timing

None

Returns:

Name Type Description
tuple

(initial_qpos, final_qpos, initial_obs_dict, final_obs_dict), where:

- initial_qpos: numpy array of initial joint positions
- final_qpos: numpy array of final joint positions
- initial_obs_dict: dictionary of initial observations from the single environment
- final_obs_dict: dictionary of final observations from the single environment after running steps

Source code in molmo_spaces/utils/test_utils.py
def run_task_for_steps_with_observations(task, policy, num_steps=10, profiler=None):
    """
    Like ``run_policy_for_steps``, but also return the first and last observations.

    The policy is registered, the task is reset, and the policy is stepped for at most
    ``num_steps`` steps (stopping early once ``task.is_done()`` is true). Joint positions of the
    first robot (concatenated in ``robot_view.move_group_ids()`` order) and the first environment's
    observation dict are captured right after the reset and again after the loop. Useful for
    checking that observations are deterministic across runs and change as expected.

    Args:
        task: The task instance
        policy: The policy instance (must expose ``get_action(observation)``)
        num_steps: Maximum number of steps to run
        profiler: Optional profiler; when given, the whole run and every reset/get_action/step
            call are bracketed with ``start``/``end``.

    Returns:
        tuple: (initial_qpos, final_qpos, initial_obs_dict, final_obs_dict) where:
            - initial_qpos: numpy array of joint positions right after reset
            - final_qpos: numpy array of joint positions after the loop
            - initial_obs_dict: observation dict of the single environment right after reset
            - final_obs_dict: observation dict of the single environment after the last step
    """
    if profiler is None:
        def begin(_section):
            pass

        finish = begin
    else:
        begin, finish = profiler.start, profiler.end

    begin("test_policy_execution_with_obs")

    task.register_policy(policy)

    begin("test_task_reset")
    observation, _info = task.reset()
    finish("test_task_reset")

    # Joint positions are read after reset: that is the state the policy starts from.
    view = task.env.robots[0].robot_view
    groups = view.move_group_ids()

    def stacked_qpos():
        by_group = view.get_qpos_dict(groups)
        return np.concatenate([by_group[group] for group in groups])

    initial_qpos = stacked_qpos()
    # reset()/step() yield a list with one observation dict per environment; keep the first (and only) one.
    initial_obs_dict = observation[0]

    for _ in range(num_steps):
        begin("test_policy_get_action")
        action_cmd = policy.get_action(observation)
        finish("test_policy_get_action")

        begin("test_task_step")
        observation, _reward, _terminal, _truncated, _infos = task.step(action_cmd)
        finish("test_task_step")

        if task.is_done():
            break

    final_qpos = stacked_qpos()
    final_obs_dict = observation[0]

    finish("test_policy_execution_with_obs")
    return initial_qpos, final_qpos, initial_obs_dict, final_obs_dict

save_observation_comparison

save_observation_comparison(obs_dict, expected_dict, output_dir, prefix='comparison')

Save visual observation comparisons including actual, expected, and difference images.

Parameters:

Name Type Description Default
obs_dict

Dictionary of actual observations from a single environment

required
expected_dict

Dictionary of expected observations with same structure

required
output_dir

Path object or string for the debug output directory

required
prefix

Prefix for the saved image filenames

'comparison'
Source code in molmo_spaces/utils/test_utils.py
def save_observation_comparison(obs_dict, expected_dict, output_dir, prefix="comparison"):
    """
    Save visual observation comparisons including actual, expected, and difference images.

    Only keys that contain "camera" (but not "sensor_param"), exist in both dicts and hold numpy
    arrays on both sides are processed. Depth sensors (keys ending in "_depth") get three plots
    via the depth_utils helpers: actual, expected and absolute error. The error plot is skipped,
    with a message, when the two depth arrays differ in shape, so one bad sensor does not abort
    the dump for the rest. Other camera arrays are clipped to [0, 255], cast to uint8 when needed,
    and written as PNGs.

    Args:
        obs_dict: Dictionary of actual observations from a single environment
        expected_dict: Dictionary of expected observations with same structure
        output_dir: Path object or string for the debug output directory (created if missing)
        prefix: Prefix for the saved image filenames
    """
    from molmo_spaces.utils.depth_utils import visualize_depth_error, visualize_depth_image

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for sensor_name, actual_data in obs_dict.items():
        if "camera" not in sensor_name or "sensor_param" in sensor_name:
            continue
        if sensor_name not in expected_dict:
            continue
        expected_data = expected_dict[sensor_name]
        if not (isinstance(actual_data, np.ndarray) and isinstance(expected_data, np.ndarray)):
            continue

        if sensor_name.endswith("_depth"):
            # Convert float16 to float32 if needed
            if expected_data.dtype == np.float16:
                expected_data = expected_data.astype(np.float32)

            visualize_depth_image(
                actual_data,
                f"{sensor_name} - Actual",
                save_path=output_dir / f"{prefix}_{sensor_name}_actual.png",
            )
            visualize_depth_image(
                expected_data,
                f"{sensor_name} - Expected",
                save_path=output_dir / f"{prefix}_{sensor_name}_expected.png",
            )

            if actual_data.shape != expected_data.shape:
                # Arrays of different shapes cannot be subtracted (numpy would raise a broadcast
                # error); report it and move on to the remaining sensors instead.
                print(
                    f"Skipping depth error image for {sensor_name}: actual shape "
                    f"{actual_data.shape} vs expected shape {expected_data.shape}"
                )
            else:
                error = np.abs(actual_data - expected_data)
                visualize_depth_error(
                    expected_data,
                    actual_data,
                    error,
                    f"{sensor_name} - Comparison Error",
                    save_path=output_dir / f"{prefix}_{sensor_name}_error.png",
                )

            print(f"Saved depth visualizations for {sensor_name} at {output_dir}")
            continue

        # RGB sensors - save as images
        if actual_data.dtype != np.uint8:
            actual_data = np.clip(actual_data, 0, 255).astype(np.uint8)
        Image.fromarray(actual_data).save(output_dir / f"{prefix}_{sensor_name}_actual.png")

        if expected_data.dtype != np.uint8:
            expected_data = np.clip(expected_data, 0, 255).astype(np.uint8)
        Image.fromarray(expected_data).save(output_dir / f"{prefix}_{sensor_name}_expected.png")
        print(
            f"Saved {prefix}_{sensor_name}_actual.png and {prefix}_{sensor_name}_expected.png at path {output_dir}"
        )

save_visual_observations

save_visual_observations(obs_dict, output_dir, prefix='obs')

Save visual observations as viewable PNG images for debugging.

Parameters:

Name Type Description Default
obs_dict

Dictionary of observations from a single environment

required
output_dir

Path object or string for the debug output directory

required
prefix

Prefix for the saved image filenames (e.g., "obs", "expected", "diff")

'obs'
Source code in molmo_spaces/utils/test_utils.py
def save_visual_observations(obs_dict, output_dir, prefix="obs"):
    """
    Dump every RGB camera observation in ``obs_dict`` to ``output_dir`` as a PNG, for debugging.

    A value is written only when its key contains "camera" but not "sensor_param" and it is a
    numpy array of shape (H, W, 3). Non-uint8 arrays are clipped to [0, 255] and cast to uint8
    first. Files are named ``{prefix}_{sensor_name}.png``; ``output_dir`` is created if missing.

    Args:
        obs_dict: Dictionary of observations from a single environment
        output_dir: Path object or string for the debug output directory
        prefix: Prefix for the saved image filenames (e.g., "obs", "expected", "diff")
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for name, data in obs_dict.items():
        is_camera = "camera" in name and "sensor_param" not in name
        if not is_camera:
            continue

        is_rgb_image = isinstance(data, np.ndarray) and data.ndim == 3 and data.shape[2] == 3
        if not is_rgb_image:
            continue

        pixels = data if data.dtype == np.uint8 else np.clip(data, 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(out_dir / f"{prefix}_{name}.png")

verify_and_compare_camera_observations

verify_and_compare_camera_observations(obs, sensor_suite, test_data_dir, test_data_prefix, expected_cameras, debug_images_dir=None, debug_prefix='obs', expected_shape=(480, 480, 3), atol=1.0, rtol=0.0, ignore_cameras=None, skip_depth_exact_match=True, ssim_threshold=0.9)

Verify observation structure and compare camera observations against saved test data.

This is a comprehensive helper for testing task observations that:

1. Verifies the observation structure (vectorized format)
2. Extracts camera sensors and checks their shapes
3. Compares them against saved test data
4. Optionally saves debug images for visual inspection
5. Verifies all expected cameras are present

Parameters:

Name Type Description Default
obs

Observation tuple from task.reset() or task.step()

required
sensor_suite

The task's sensor suite

required
test_data_dir

Path to directory containing test data files

required
test_data_prefix

Prefix for test data files (e.g., "rum_pick_obs_")

required
expected_cameras

List of expected camera sensor names

required
debug_images_dir

Optional path to save debug images. If None, no images are saved.

None
debug_prefix

Prefix for debug image filenames (default: "obs")

'obs'
expected_shape

Expected shape of camera observations (default: (480, 480, 3))

(480, 480, 3)
atol

Absolute tolerance for pixel value comparison (default: 1.0). For uint8 images [0-255], 1.0 lets each pixel value differ by up to one intensity level, absorbing floating-point precision or slight numerical variations.

1.0
rtol

Relative tolerance for comparison (default: 0.0)

0.0
ignore_cameras

Optional list of camera sensor names to skip during comparison

None
skip_depth_exact_match

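Whether to skip pixel-exact depth comparison (default: True). When True, uses structural similarity (SSIM) on normalized depth for cross-platform robustness. When False, does pixel-exact comparison with edge masking (for local determinism tests). Depth rendering is NOT deterministic across platforms/GPUs.

True
ssim_threshold

Minimum SSIM score required for the depth comparison when skip_depth_exact_match is True; also passed to the RGB observation comparison.

0.9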

Returns:

Name Type Description
tuple

(obs_dict, camera_sensors_found) for further testing if needed

Source code in molmo_spaces/utils/test_utils.py
def verify_and_compare_camera_observations(
    obs,
    sensor_suite,
    test_data_dir,
    test_data_prefix,
    expected_cameras,
    debug_images_dir=None,
    debug_prefix="obs",
    expected_shape=(480, 480, 3),
    atol=1.0,
    rtol=0.0,
    ignore_cameras=None,
    skip_depth_exact_match=True,
    ssim_threshold=0.9,
):
    """
    Verify observation structure and compare camera observations against saved test data.

    This is a comprehensive helper for testing task observations that:
    1. Verifies the observation structure (vectorized format)
    2. Extracts camera sensors and checks their shapes
    3. Compares them against saved test data
    4. Optionally saves debug images for visual inspection
    5. Verifies all expected cameras are present

    Args:
        obs: Observation tuple from task.reset() or task.step()
        sensor_suite: The task's sensor suite
        test_data_dir: Path to directory containing test data files
        test_data_prefix: Prefix for test data files (e.g., "rum_pick_obs_")
        expected_cameras: List of expected camera sensor names
        debug_images_dir: Optional path to save debug images. If None, no images are saved.
        debug_prefix: Prefix for debug image filenames (default: "obs")
        expected_shape: Expected shape of camera observations (default: (480, 480, 3))
        atol: Absolute tolerance for pixel value comparison (default: 1.0).
              For uint8 images [0-255], 1.0 allows single-pixel differences due to
              floating point precision or slight numerical variations.
        rtol: Relative tolerance for comparison (default: 0.0)
        ignore_cameras: Optional list of camera sensor names to skip during comparison
        skip_depth_exact_match: Whether to skip pixel-exact depth comparison (default: True).
              When True, uses structural similarity (SSIM) on normalized depth for cross-platform
              robustness. When False, does pixel-exact comparison with edge masking (for local
              determinism tests). Depth rendering is NOT deterministic across platforms/GPUs.

    Returns:
        tuple: (obs_dict, camera_sensors_found) for further testing if needed
    """
    test_data_dir = Path(test_data_dir)
    ignore_cameras = ignore_cameras or []

    # Verify observations structure (vectorized format)
    assert obs is not None, "Observations should not be None"
    assert isinstance(obs, tuple) and len(obs) == 2, (
        "Observations should be a tuple of (obs_list, info_dict)"
    )
    obs_list, info_dict = obs
    assert isinstance(obs_list, list), "First element should be list of observation dictionaries"
    assert len(obs_list) == 1, f"Expected 1 environment observation, got {len(obs_list)}"

    # Extract the single environment's observations
    obs_dict = obs_list[0]
    assert isinstance(obs_dict, dict), "Observation should be a dictionary"
    assert len(obs_dict) == len(sensor_suite.sensors), (
        f"Expected {len(sensor_suite.sensors)} sensors, got {len(obs_dict)}"
    )

    # Track which camera sensors we find and whether any assertions failed
    camera_sensors_found = []
    assertion_failed = False

    try:
        for sensor_name in sensor_suite.sensors:
            assert sensor_name in obs_dict, f"Sensor {sensor_name} not found in observations"

            if "camera" in sensor_name and "sensor_param" not in sensor_name:
                camera_sensors_found.append(sensor_name)

                # Skip comparison for ignored cameras
                if sensor_name in ignore_cameras:
                    print(f"[SKIP] Ignoring camera {sensor_name} as requested")
                    continue

                sensor_obs = obs_dict[sensor_name]

                # Verify basic properties
                assert sensor_obs is not None, f"Observation for {sensor_name} is None"

                # Handle depth sensors separately (2D) vs RGB sensors (3D)
                if sensor_name.endswith("_depth"):
                    # Depth sensors are 2D (H, W)
                    assert sensor_obs.ndim == 2, (
                        f"Depth sensor {sensor_name} should be 2D (H, W), got {sensor_obs.ndim}D"
                    )
                    # Verify depth shape (expected_shape is (W, H, C), extract W, H for depth)
                    sensor_shape_swapped = (sensor_obs.shape[1], sensor_obs.shape[0])
                    expected_depth_shape = (expected_shape[0], expected_shape[1])
                    assert sensor_shape_swapped == expected_depth_shape, (
                        f"Expected depth shape {expected_depth_shape}, got {sensor_obs.shape} (h/w swapped)"
                    )
                else:
                    # RGB sensors are 3D (H, W, C)
                    assert sensor_obs.ndim == 3, (
                        f"RGB sensor {sensor_name} should be 3D (H, W, C), got {sensor_obs.ndim}D"
                    )
                    sensor_shape_swapped = (
                        sensor_obs.shape[1],
                        sensor_obs.shape[0],
                        sensor_obs.shape[2],
                    )
                    assert sensor_shape_swapped == expected_shape, (
                        f"Expected shape {expected_shape}, got {sensor_obs.shape} (h/w swapped)"
                    )

                # Load and compare against saved test data for regression testing
                test_data_path = test_data_dir / f"{test_data_prefix}{sensor_name}.npy"
                expected_obs = np.load(test_data_path)

                # Handle depth sensor comparison differently
                if sensor_name.endswith("_depth"):
                    # Convert float16 to float32 if needed (depth saved as float16 to reduce file size)
                    if expected_obs.dtype == np.float16:
                        expected_obs = expected_obs.astype(np.float32)

                    # Skip pixel-exact comparison if requested (for cross-platform CI)
                    # Depth rendering is NOT deterministic across platforms/GPUs/drivers
                    if skip_depth_exact_match:
                        print(
                            f"[DEPTH] Using structural similarity (SSIM) for {sensor_name} "
                            f"(skip_depth_exact_match=True). Checks structure, not exact pixels."
                        )

                        # Basic sanity checks first
                        assert not np.any(np.isnan(sensor_obs)), (
                            f"{sensor_name} contains NaN values"
                        )
                        assert not np.any(np.isinf(sensor_obs)), (
                            f"{sensor_name} contains inf values"
                        )
                        assert np.any(sensor_obs > 0), f"{sensor_name} is all zeros"

                        # Normalize depth to [0, 255] range for SSIM comparison
                        # Use a reasonable depth range (e.g., 0-2m for wrist camera)
                        depth_min_for_viz = 0.0
                        depth_max_for_viz = 2.0

                        def normalize_depth_for_comparison(
                            depth, depth_min=depth_min_for_viz, depth_max=depth_max_for_viz
                        ):
                            """Normalize depth to uint8 [0, 255] for SSIM."""
                            depth_clipped = np.clip(depth, depth_min, depth_max)
                            depth_normalized = (depth_clipped - depth_min) / (depth_max - depth_min)
                            return (depth_normalized * 255).astype(np.uint8)

                        sensor_obs_normalized = normalize_depth_for_comparison(sensor_obs)
                        expected_obs_normalized = normalize_depth_for_comparison(expected_obs)

                        # Use SSIM to compare depth structure (same as RGB comparison)
                        # SSIM is robust to small numerical differences while catching major changes
                        ssim_score = ssim(
                            sensor_obs_normalized,
                            expected_obs_normalized,
                            data_range=255,
                        )

                        # Depth should have high structural similarity, same as RGB
                        # SSIM is designed to be robust to small numerical differences
                        MIN_DEPTH_SSIM_THRESHOLD = ssim_threshold

                        if ssim_score < MIN_DEPTH_SSIM_THRESHOLD:
                            # Calculate basic pixel stats for debugging
                            diff = np.abs(sensor_obs - expected_obs)
                            diff_max = np.max(diff)
                            diff_mean = np.mean(diff)
                            num_different = np.sum(diff > 0.01)  # >10mm difference
                            total_pixels = diff.size
                            percent_different = 100 * num_different / total_pixels

                            raise AssertionError(
                                f"Depth sensor {sensor_name} has low structural similarity to saved test data. "
                                f"SSIM score: {ssim_score:.4f} (threshold: {MIN_DEPTH_SSIM_THRESHOLD}). "
                                f"This suggests major rendering differences across platforms. "
                                f"Pixel-level stats: max diff={diff_max * 1000:.1f}mm, "
                                f"mean diff={diff_mean * 1000:.1f}mm, "
                                f"{percent_different:.1f}% pixels differ >10mm"
                            )

                        print(
                            f"  ✓ Depth SSIM: {ssim_score:.4f} (threshold: {MIN_DEPTH_SSIM_THRESHOLD})"
                        )
                        continue

                    # If we get here, we're doing exact comparison (for local determinism tests)
                    from molmo_spaces.utils.depth_utils import detect_depth_edges

                    # Detect edges in expected depth (where large errors are expected from small motion)
                    # Use a moderate threshold (50mm gradient) to catch occlusion boundaries
                    edge_mask = detect_depth_edges(expected_obs, gradient_threshold_mm=50.0)

                    # Create mask for smooth (non-edge) regions
                    smooth_mask = ~edge_mask

                    # Calculate differences
                    diff = np.abs(sensor_obs - expected_obs)

                    # Compare smooth regions with tight tolerance (5mm)
                    # Edge regions may have large differences due to slight motion/alignment
                    if np.sum(smooth_mask) > 0:
                        smooth_diff = diff[smooth_mask]
                        max_smooth_diff = np.max(smooth_diff)
                        mean_smooth_diff = np.mean(smooth_diff)

                        # Check smooth regions only (edges can differ due to motion)
                        if max_smooth_diff > 0.005:
                            # Calculate stats for error reporting
                            max_diff_overall = np.max(diff)
                            mean_diff_overall = np.mean(diff)
                            edge_pixels = np.sum(edge_mask)
                            smooth_pixels = np.sum(smooth_mask)

                            raise AssertionError(
                                f"Depth sensor {sensor_name} differs from saved test data. "
                                f"Smooth regions (non-edges): Max diff: {max_smooth_diff * 1000:.3f}mm, "
                                f"Mean diff: {mean_smooth_diff * 1000:.3f}mm "
                                f"({smooth_pixels:,} pixels). "
                                f"Overall: Max diff: {max_diff_overall * 1000:.3f}mm, "
                                f"Mean diff: {mean_diff_overall * 1000:.3f}mm "
                                f"({edge_pixels:,} edge pixels masked out)"
                            )
                    else:
                        raise AssertionError(
                            f"Depth sensor {sensor_name}: All pixels are edges, cannot compare"
                        )
                else:
                    # Compare RGB observations using SSIM
                    assert_observations_match(
                        sensor_obs,
                        expected_obs,
                        sensor_name,
                        atol=atol,
                        rtol=rtol,
                        ssim_threshold=ssim_threshold,
                    )

        # Verify we found all expected camera sensors
        for expected_camera in expected_cameras:
            assert expected_camera in camera_sensors_found, (
                f"Expected camera {expected_camera} not found in observations"
            )

    except AssertionError:
        assertion_failed = True
        raise
    finally:
        # Always save debug images for visual inspection (if requested)
        if debug_images_dir is not None and camera_sensors_found:
            expected_dict = {}
            for sensor_name in camera_sensors_found:
                # Skip loading ignored cameras for debug images too
                if sensor_name in ignore_cameras:
                    continue
                test_data_path = test_data_dir / f"{test_data_prefix}{sensor_name}.npy"
                expected_dict[sensor_name] = np.load(test_data_path)

            # Add FAILED_ prefix if test failed
            final_prefix = f"FAILED_{debug_prefix}" if assertion_failed else debug_prefix
            save_observation_comparison(
                obs_dict, expected_dict, debug_images_dir, prefix=final_prefix
            )

    return obs_dict, camera_sensors_found

verify_and_compare_camera_observations_after_steps

verify_and_compare_camera_observations_after_steps(obs_dict, sensor_suite, test_data_dir, test_data_prefix, expected_cameras, initial_obs_dict=None, debug_images_dir=None, debug_prefix='obs_after_steps', expected_shape=(480, 480, 3), atol=1.0, rtol=0.0, ignore_cameras=None, skip_depth_exact_match=True, ssim_threshold=0.9)

Verify and compare camera observations after running policy steps against saved test data.

Similar to verify_and_compare_camera_observations, but expects obs_dict directly rather than the tuple format from task.reset()/task.step().

Parameters:

Name Type Description Default
obs_dict

Dictionary of observations from a single environment

required
sensor_suite

The task's sensor suite

required
test_data_dir

Path to directory containing test data files

required
test_data_prefix

Prefix for test data files (e.g., "rum_pick_after_steps_")

required
expected_cameras

List of expected camera sensor names

required
initial_obs_dict

Optional dict of initial observations to verify that observations changed

None
debug_images_dir

Optional path to save debug images. If None, no images are saved.

None
debug_prefix

Prefix for debug image filenames (default: "obs_after_steps")

'obs_after_steps'
expected_shape

Expected shape of camera observations (w,h,c) (default: (480, 480, 3))

(480, 480, 3)
atol

Absolute tolerance for pixel value comparison (default: 1.0)

1.0
rtol

Relative tolerance for comparison (default: 0.0)

0.0
ignore_cameras

Optional list of camera sensor names to skip during comparison

None
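skip_depth_exact_match

Whether to skip pixel-exact depth comparison (default: True). When True, uses structural similarity (SSIM) on normalized depth for cross-platform robustness. When False, does pixel-exact comparison with edge masking (for local determinism tests). Depth rendering is NOT deterministic across platforms/GPUs.

True
ssim_threshold

Minimum SSIM score accepted for depth sensors when skip_depth_exact_match is True; also forwarded to the RGB observation comparison (default: 0.9).

0.9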

Returns:

Name Type Description
list

camera_sensors_found for further testing if needed

Source code in molmo_spaces/utils/test_utils.py
def verify_and_compare_camera_observations_after_steps(
    obs_dict,
    sensor_suite,
    test_data_dir,
    test_data_prefix,
    expected_cameras,
    initial_obs_dict=None,
    debug_images_dir=None,
    debug_prefix="obs_after_steps",
    expected_shape=(480, 480, 3),
    atol=1.0,
    rtol=0.0,
    ignore_cameras=None,
    skip_depth_exact_match=True,
    ssim_threshold=0.9,
):
    """
    Verify and compare camera observations after running policy steps against saved test data.

    Similar to verify_and_compare_camera_observations, but expects obs_dict directly
    rather than the tuple format from task.reset()/task.step().

    A sensor counts as a camera when its name contains "camera" but not "sensor_param".
    Camera sensors whose name ends in "_depth" are checked as 2D depth maps, all other
    cameras as 3D RGB images. Reference data for each camera is loaded from
    ``test_data_dir / f"{test_data_prefix}{sensor_name}.npy"``.

    Args:
        obs_dict: Dictionary of observations from a single environment
        sensor_suite: The task's sensor suite
        test_data_dir: Path to directory containing test data files
        test_data_prefix: Prefix for test data files (e.g., "rum_pick_after_steps_")
        expected_cameras: List of expected camera sensor names
        initial_obs_dict: Optional dict of initial observations to verify that observations changed
        debug_images_dir: Optional path to save debug images. If None, no images are saved.
        debug_prefix: Prefix for debug image filenames (default: "obs_after_steps")
        expected_shape: Expected shape of camera observations (w,h,c) (default: (480, 480, 3))
        atol: Absolute tolerance for pixel value comparison (default: 1.0)
        rtol: Relative tolerance for comparison (default: 0.0)
        ignore_cameras: Optional list of camera sensor names to skip during comparison
        skip_depth_exact_match: Whether to skip pixel-exact depth comparison (default: True).
              When True, uses structural similarity (SSIM) on normalized depth for cross-platform
              robustness. When False, does pixel-exact comparison with edge masking (for local
              determinism tests). Depth rendering is NOT deterministic across platforms/GPUs.
        ssim_threshold: Minimum SSIM score accepted for depth sensors when
              skip_depth_exact_match is True; also forwarded to assert_observations_match
              for RGB sensors (default: 0.9).

    Returns:
        list: camera_sensors_found for further testing if needed

    Raises:
        AssertionError: If the observation structure or shapes are wrong, a camera does not
            match its saved test data, an expected camera is missing, or (when
            initial_obs_dict is given) a camera's observation did not change.
    """
    test_data_dir = Path(test_data_dir)
    ignore_cameras = ignore_cameras or []

    # Verify observations structure
    assert isinstance(obs_dict, dict), "Observation should be a dictionary"
    assert len(obs_dict) == len(sensor_suite.sensors), (
        f"Expected {len(sensor_suite.sensors)} sensors, got {len(obs_dict)}"
    )

    # Depth range (metres) that is clipped and quantized to uint8 before the SSIM
    # comparison; a reasonable range for e.g. a wrist camera is 0-2m.
    depth_min_for_viz = 0.0
    depth_max_for_viz = 2.0

    def normalize_depth_for_comparison(
        depth, depth_min=depth_min_for_viz, depth_max=depth_max_for_viz
    ):
        """Normalize depth to uint8 [0, 255] for SSIM."""
        depth_clipped = np.clip(depth, depth_min, depth_max)
        depth_normalized = (depth_clipped - depth_min) / (depth_max - depth_min)
        return (depth_normalized * 255).astype(np.uint8)

    # Track which camera sensors we find and whether anything failed
    camera_sensors_found = []
    test_failed = False

    try:
        for sensor_name in sensor_suite.sensors:
            assert sensor_name in obs_dict, f"Sensor {sensor_name} not found in observations"

            if "camera" in sensor_name and "sensor_param" not in sensor_name:
                camera_sensors_found.append(sensor_name)

                # Skip comparison for ignored cameras
                if sensor_name in ignore_cameras:
                    print(f"[SKIP] Ignoring camera {sensor_name} as requested")
                    continue

                sensor_obs = obs_dict[sensor_name]

                # Handle depth sensors separately (2D) vs RGB sensors (3D)
                if sensor_name.endswith("_depth"):
                    # Depth sensors are 2D (H, W)
                    assert sensor_obs.ndim == 2, (
                        f"Depth sensor {sensor_name} should be 2D (H, W), got {sensor_obs.ndim}D"
                    )
                    # Verify depth shape (expected_shape is (W, H, C), extract W, H for depth)
                    sensor_shape_swapped = (sensor_obs.shape[1], sensor_obs.shape[0])
                    expected_depth_shape = (expected_shape[0], expected_shape[1])
                    assert sensor_shape_swapped == expected_depth_shape, (
                        f"Expected depth shape {expected_depth_shape}, got {sensor_obs.shape} (h/w swapped)"
                    )
                else:
                    # RGB sensors are 3D (H, W, C)
                    assert sensor_obs.ndim == 3, (
                        f"RGB sensor {sensor_name} should be 3D (H, W, C), got {sensor_obs.ndim}D"
                    )
                    sensor_shape_swapped = (
                        sensor_obs.shape[1],
                        sensor_obs.shape[0],
                        sensor_obs.shape[2],
                    )
                    assert sensor_shape_swapped == expected_shape, (
                        f"Expected shape {expected_shape}, got {sensor_obs.shape} (h/w swapped)"
                    )

                # Load and compare against saved test data for regression testing
                test_data_path = test_data_dir / f"{test_data_prefix}{sensor_name}.npy"
                expected_obs = np.load(test_data_path)

                # Handle depth sensor comparison differently
                if sensor_name.endswith("_depth"):
                    # Convert float16 to float32 if needed (depth saved as float16 to reduce file size)
                    if expected_obs.dtype == np.float16:
                        expected_obs = expected_obs.astype(np.float32)

                    # Skip pixel-exact comparison if requested (for cross-platform CI)
                    # Depth rendering is NOT deterministic across platforms/GPUs/drivers
                    if skip_depth_exact_match:
                        print(
                            f"[DEPTH] Using structural similarity (SSIM) for {sensor_name} "
                            f"(skip_depth_exact_match=True). Checks structure, not exact pixels."
                        )

                        # Basic sanity checks first
                        assert not np.any(np.isnan(sensor_obs)), (
                            f"{sensor_name} contains NaN values"
                        )
                        assert not np.any(np.isinf(sensor_obs)), (
                            f"{sensor_name} contains inf values"
                        )
                        assert np.any(sensor_obs > 0), f"{sensor_name} is all zeros"

                        sensor_obs_normalized = normalize_depth_for_comparison(sensor_obs)
                        expected_obs_normalized = normalize_depth_for_comparison(expected_obs)

                        # SSIM is robust to small numerical differences while catching
                        # major structural changes (same metric as the RGB comparison)
                        ssim_score = ssim(
                            sensor_obs_normalized,
                            expected_obs_normalized,
                            data_range=255,
                        )

                        if ssim_score < ssim_threshold:
                            # Calculate basic pixel stats for debugging
                            diff = np.abs(sensor_obs - expected_obs)
                            diff_max = np.max(diff)
                            diff_mean = np.mean(diff)
                            num_different = np.sum(diff > 0.01)  # >10mm difference
                            total_pixels = diff.size
                            percent_different = 100 * num_different / total_pixels

                            raise AssertionError(
                                f"Depth sensor {sensor_name} has low structural similarity to saved test data. "
                                f"SSIM score: {ssim_score:.4f} (threshold: {ssim_threshold}). "
                                f"This suggests major rendering differences across platforms. "
                                f"Pixel-level stats: max diff={diff_max * 1000:.1f}mm, "
                                f"mean diff={diff_mean * 1000:.1f}mm, "
                                f"{percent_different:.1f}% pixels differ >10mm"
                            )

                        print(
                            f"  ✓ Depth SSIM: {ssim_score:.4f} (threshold: {ssim_threshold})"
                        )
                        continue

                    # If we get here, we're doing exact comparison (for local determinism tests)
                    from molmo_spaces.utils.depth_utils import detect_depth_edges

                    # Detect edges in expected depth (where large errors are expected from small motion)
                    # Use a moderate threshold (50mm gradient) to catch occlusion boundaries
                    edge_mask = detect_depth_edges(expected_obs, gradient_threshold_mm=50.0)

                    # Create mask for smooth (non-edge) regions
                    smooth_mask = ~edge_mask

                    # Calculate differences
                    diff = np.abs(sensor_obs - expected_obs)

                    # Compare smooth regions with tight tolerance (5mm)
                    # Edge regions may have large differences due to slight motion/alignment
                    if np.sum(smooth_mask) > 0:
                        smooth_diff = diff[smooth_mask]
                        max_smooth_diff = np.max(smooth_diff)
                        mean_smooth_diff = np.mean(smooth_diff)

                        # Check smooth regions only (edges can differ due to motion)
                        if max_smooth_diff > 0.005:
                            # Calculate stats for error reporting
                            max_diff_overall = np.max(diff)
                            mean_diff_overall = np.mean(diff)
                            edge_pixels = np.sum(edge_mask)
                            smooth_pixels = np.sum(smooth_mask)

                            raise AssertionError(
                                f"Depth sensor {sensor_name} differs from saved test data. "
                                f"Smooth regions (non-edges): Max diff: {max_smooth_diff * 1000:.3f}mm, "
                                f"Mean diff: {mean_smooth_diff * 1000:.3f}mm "
                                f"({smooth_pixels:,} pixels). "
                                f"Overall: Max diff: {max_diff_overall * 1000:.3f}mm, "
                                f"Mean diff: {mean_diff_overall * 1000:.3f}mm "
                                f"({edge_pixels:,} edge pixels masked out)"
                            )
                    else:
                        raise AssertionError(
                            f"Depth sensor {sensor_name}: All pixels are edges, cannot compare"
                        )
                else:
                    # Compare RGB observations using SSIM
                    assert_observations_match(
                        sensor_obs,
                        expected_obs,
                        sensor_name,
                        atol=atol,
                        rtol=rtol,
                        ssim_threshold=ssim_threshold,
                    )

        # Verify we found all expected camera sensors
        for expected_camera in expected_cameras:
            assert expected_camera in camera_sensors_found, (
                f"Expected camera {expected_camera} not found in observations"
            )

        # Verify observations changed from initial (if initial observations provided)
        if initial_obs_dict is not None:
            for sensor_name in camera_sensors_found:
                # Skip ignored cameras from this verification too
                if sensor_name in ignore_cameras:
                    continue

                if sensor_name in initial_obs_dict:
                    initial_obs = initial_obs_dict[sensor_name]
                    final_obs = obs_dict[sensor_name]

                    # Subtract in float64: this avoids uint8 wrap-around for RGB and,
                    # unlike an int cast, does not truncate float depth (metres) to whole
                    # metres, which would hide sub-metre depth changes.
                    diff = np.abs(final_obs.astype(np.float64) - initial_obs.astype(np.float64))
                    num_different_pixels = np.sum(diff > 0)
                    total_pixels = diff.size
                    percent_changed = 100 * num_different_pixels / total_pixels

                    # Observations should have changed at least slightly (>0.1% of pixels)
                    assert percent_changed > 0.1, (
                        f"Sensor {sensor_name}: Observations didn't change after running steps. "
                        f"Only {num_different_pixels}/{total_pixels} pixels ({percent_changed:.4f}%) changed. "
                        f"This suggests the robot didn't move or the camera view didn't update."
                    )

                    print(f"[OBS CHANGE] {sensor_name}: {percent_changed:.2f}% of pixels changed")

    except Exception:
        # Any failure (an assertion, or e.g. a missing reference file) marks the debug images
        test_failed = True
        raise
    finally:
        # Always save debug images for visual inspection (if requested)
        if debug_images_dir is not None and camera_sensors_found:
            expected_dict = {}
            for sensor_name in camera_sensors_found:
                # Skip loading ignored cameras for debug images too
                if sensor_name in ignore_cameras:
                    continue
                test_data_path = test_data_dir / f"{test_data_prefix}{sensor_name}.npy"
                try:
                    expected_dict[sensor_name] = np.load(test_data_path)
                except FileNotFoundError:
                    # A missing reference must not replace the original error raised above
                    # or prevent the other debug images from being written.
                    continue

            # Add FAILED_ prefix if test failed
            final_prefix = f"FAILED_{debug_prefix}" if test_failed else debug_prefix
            save_observation_comparison(
                obs_dict, expected_dict, debug_images_dir, prefix=final_prefix
            )

    return camera_sensors_found

verify_video_fps

verify_video_fps(dir: Path, expected_fps: float)

Assert all videos in a directory have the expected FPS.

Source code in molmo_spaces/utils/test_utils.py
def verify_video_fps(dir: "Path | str", expected_fps: float):
    """Assert all videos in a directory have the expected FPS.

    Args:
        dir: Directory scanned (non-recursively) for ``*.mp4`` files. A ``str`` path is
            accepted as well as a ``Path``.
        expected_fps: FPS every video must report; compared with an absolute
            tolerance of 1e-2.

    Raises:
        AssertionError: If any video's average FPS differs from ``expected_fps``.

    Note:
        A directory without ``*.mp4`` files passes vacuously.
    """
    video_dir = Path(dir)
    # sorted() makes the first reported failure deterministic across filesystems
    for vid_file in sorted(video_dir.glob("*.mp4")):
        vr = decord.VideoReader(str(vid_file))
        fps = vr.get_avg_fps()
        assert np.isclose(fps, expected_fps, atol=1e-2), (
            f"Expected {vid_file} to be {expected_fps} fps, got {fps}"
        )

video_utils

Copied from video2sim_pipeline/video2sim/utils/video_utils.py

Functions:

Name Description
ffmpeg_save_video

Save a video using ffmpeg.

resize_with_padding

Resize an image to fit within the target size while maintaining its original aspect ratio.

ffmpeg_save_video

ffmpeg_save_video(frames, output_path: str, fps: float = 30.0, codec: str = 'libx264', quality: int = 23, pix_fmt='rgb24')

Save a video using ffmpeg.

Parameters:

Name Type Description Default
frames

Video frames as numpy array (T, H, W, 3) or torch tensor (T, 3, H, W)

required
output_path str

Path to save the video file

required
fps float

Frames per second

30.0
codec str

Video codec to use

'libx264'
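quality int

CRF value (lower is better quality, 18-28 is reasonable)

23
pix_fmt

Pixel format of the raw input frames passed to ffmpeg (e.g. "rgb24", or "bgr24" for OpenCV-ordered frames). The encoded output is always yuv420p.

'rgb24'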
Source code in molmo_spaces/utils/video_utils.py
def ffmpeg_save_video(
    frames,
    output_path: str,
    fps: float = 30.0,
    codec: str = "libx264",
    quality: int = 23,  # Lower CRF means higher quality (18-28 is good range)
    pix_fmt="rgb24",  # opencv
):
    """
    Save a video using ffmpeg.

    Args:
        frames: Video frames as numpy array (T, H, W, 3) or torch tensor (T, 3, H, W),
            or a list of (H, W, 3) numpy frames.
        output_path: Path to save the video file
        fps: Frames per second
        codec: Video codec to use
        quality: CRF value (lower is better quality, 18-28 is reasonable)
        pix_fmt: Pixel format of the raw input frames handed to ffmpeg (e.g. "rgb24",
            or "bgr24" for OpenCV-ordered frames). The encoded output is always yuv420p.

    Returns:
        str: ``output_path``.

    Raises:
        AssertionError: If the frames are not in (T, H, W, 3) layout after conversion.

    Note:
        ffmpeg failures (broken pipe, non-zero exit code) are printed together with
        ffmpeg's stderr instead of being raised.
    """
    # Create output directory if it doesn't exist. A bare filename has no directory
    # component, and os.makedirs("") would raise FileNotFoundError.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Convert torch tensor to numpy if needed (detach so tensors requiring grad work too)
    if isinstance(frames, torch.Tensor):
        frames = frames.detach().cpu().numpy()
        if frames.shape[1] == 3:  # Convert from TCHW to THWC
            frames = rearrange(frames, "T C H W -> T H W C")

    # Ensure frames are uint8.
    # NOTE: list frames are cast directly, while non-uint8 arrays are assumed to lie in
    # [0, 1] and are scaled by 255 -- callers must supply the matching value range.
    if isinstance(frames, list):
        if frames and frames[0].dtype != np.uint8:
            frames = [frame.astype(np.uint8) for frame in frames]
        frames = np.array(frames)
    elif frames.dtype != np.uint8:
        frames = (frames * 255).astype(np.uint8)

    assert frames.ndim == 4 and frames.shape[-1] == 3, (
        f"Expected THWC format, got shape {frames.shape}"
    )

    # Set up ffmpeg process. stderr is piped but only drained after all frames are
    # written, so "-nostats -loglevel error" keeps it small; verbose progress output
    # could otherwise fill the pipe buffer and deadlock both processes.
    process = (
        ffmpeg.input(
            "pipe:",
            format="rawvideo",
            pix_fmt=pix_fmt,  # "bgr24", #"rgb24",
            s=f"{frames.shape[2]}x{frames.shape[1]}",
            r=fps,
        )
        .output(output_path, pix_fmt="yuv420p", vcodec=codec, crf=quality)
        .global_args("-nostats", "-loglevel", "error")
        .overwrite_output()
        .run_async(pipe_stdin=True, pipe_stderr=True)
    )

    # Write frames; once ffmpeg has closed its end, further writes are pointless.
    broken_at = None
    for i, frame in enumerate(frames):
        try:
            process.stdin.write(frame.tobytes())
        except BrokenPipeError:
            broken_at = i
            break

    # communicate() closes stdin (tolerating a broken pipe), drains stderr and waits.
    _, stderr_bytes = process.communicate()
    stderr_output = stderr_bytes.decode(errors="replace") if stderr_bytes else "No stderr"

    if broken_at is not None:
        print(f"[FFMPEG ERROR] Broken pipe after writing frame {broken_at}.")
        print(f"[FFMPEG STDERR]\n{stderr_output}")
    elif process.returncode != 0:
        print(f"[FFMPEG ERROR] ffmpeg exited with code {process.returncode}.")
        print(f"[FFMPEG STDERR]\n{stderr_output}")

    return output_path

resize_with_padding

resize_with_padding(image, target_width, target_height, pad_color=(0, 0, 0))

Resize an image to fit within the target size while maintaining its original aspect ratio. Padding (letterbox) is added to ensure the final image matches the target dimensions.

Parameters:

Name Type Description Default
image array

Input image.

required
target_width int

Desired width.

required
target_height int

Desired height.

required
pad_color tuple

Color for the padding (default is black).

(0, 0, 0)

Returns:

Type Description

np.array: Resized image with padding.

Source code in molmo_spaces/utils/video_utils.py
def resize_with_padding(image, target_width, target_height, pad_color=(0, 0, 0)):
    """
    Resize an image to fit within the target size while maintaining its original aspect ratio.
    Padding (letterbox) is added to ensure the final image matches the target dimensions.

    Args:
        image (np.array): Input image.
        target_width (int): Desired width.
        target_height (int): Desired height.
        pad_color (tuple): Color for the padding (default is black).

    Returns:
        np.array: Resized image with padding.
    """
    h, w = image.shape[:2]

    # Decide the limiting side with an exact integer comparison of
    # target_width / w vs target_height / h. That side gets exactly its target size, and
    # the other side is floored with integer arithmetic. The former float form
    # int((t / w) * w) could land just below t and truncate to t - 1 (or to 0).
    if target_width * h <= target_height * w:
        new_w = target_width
        new_h = (h * target_width) // w
    else:
        new_h = target_height
        new_w = (w * target_height) // h
    # cv2.resize rejects zero-sized outputs, which extreme aspect ratios would produce
    new_w, new_h = int(max(1, new_w)), int(max(1, new_h))

    resized_image = cv2.resize(image, (new_w, new_h))

    pad_w = target_width - new_w
    pad_h = target_height - new_h

    # Split padding evenly; any odd pixel goes to the bottom/right side
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left

    padded_image = cv2.copyMakeBorder(
        resized_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_color
    )
    return padded_image