DataFiltering

This guide explains how to perform data filtering based on error patterns using the MLdebugger SDK.

Overview

The DataFiltering flow is a workflow for efficiently selecting data from an unlabeled data pool based on evaluation results. Two methods are supported:

  • Batch Post-Processing Method: Select data in bulk using the query() method after running inference on all data
  • Real-time Filtering Method: Evaluate filter conditions in real-time during inference

The DataFilter class used in STEP 1 varies by task:

Task                  DataFilter Class
Classification        ClassificationDataFilter
Object Detection      ObjectDetectionDataFilter
3D Object Detection   ObjectDetection3DDataFilter

Prerequisite: Tracing + Evaluation must be complete and result_name must be obtained.

STEP 1: Initialize DataFilter

Classification

Batch Post-Processing Method (Basic)

from ml_debugger.data_filter import ClassificationDataFilter

data_filter = ClassificationDataFilter(
    model,                                      # Model to evaluate
    model_name="resnet18",                      # model_name used with Tracer
    version_name="v1",                          # version_name used with Tracer
    result_name="resnet18_v1_classification_v1_20251219",  # result_name of evaluation
)

Real-time Filtering Method

If you want to filter data in real-time during inference, use the filter_config parameter.

data_filter = ClassificationDataFilter(
    model,
    model_name="resnet18",
    version_name="v1",
    result_name="resnet18_v1_classification_v1_20251219",
    filter_config={"target_zones": ["critical_hotspot"]},  # Target Critical Hotspot zone
)

Selecting target_zones

  • critical_hotspot: Data where the model makes incorrect predictions with high confidence. Most important for quality improvement.
  • hotspot: General error zone. Use when you want to collect a wide range of error patterns.
  • stable_coverage: Data where the model makes correct predictions with high confidence. Can be used to thin out data with low learning contribution.
  • Multiple zones can be specified: {"target_zones": ["critical_hotspot", "hotspot"]} (OR condition)

Details and conditions for available zones can be viewed on the Heatmap screen in the GUI.

Object Detection

Batch Post-Processing Method (Basic)

from ml_debugger.data_filter import ObjectDetectionDataFilter

data_filter = ObjectDetectionDataFilter(
    model,                                      # Object Detection model
    model_name="faster_rcnn",                   # model_name used with Tracer
    version_name="v1",                          # version_name used with Tracer
    result_name="faster_rcnn_v1_od_v1_20251219",  # result_name of evaluation
)

Real-time Filtering Method

For Object Detection, pass a BBoxStrategy dict to filter_config. See the Object Detection tab in STEP 3 for details.

data_filter = ObjectDetectionDataFilter(
    model,
    model_name="faster_rcnn",
    version_name="v1",
    result_name="faster_rcnn_v1_od_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

3D Object Detection

Batch Post-Processing Method (Basic)

from ml_debugger.data_filter import ObjectDetection3DDataFilter

data_filter = ObjectDetection3DDataFilter(
    model,                                      # 3D Object Detection model
    model_name="centerpoint",                   # model_name used with Tracer
    version_name="v1",                          # version_name used with Tracer
    result_name="centerpoint_v1_od3d_v1_20251219",  # result_name of evaluation
)

Real-time Filtering Method

Like 2D Object Detection, pass a BBoxStrategy dict to filter_config. See the Object Detection tab in STEP 3 for details.

data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="centerpoint",
    version_name="v1",
    result_name="centerpoint_v1_od3d_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

Getting result_name

result_name can be obtained from the Result object returned by Evaluator.request_evaluation().

result = evaluator.request_evaluation()
result_name = result.result_name

STEP 2: Inference on Pool Data

Batch Post-Processing Method

Run inference on unlabeled pool data and collect internal features. When filter_config is not set, all filter_flags will be None.

import torch

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)

    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,                             # Input data
            input_ids=ids,                          # Data identifiers
        )
        # When filter_config is not set, filter_flags are all None

Labels Not Required

DataFilter does not require ground truth labels. It only collects inference results and internal features from pool data.

3D Object Detection Input Format

For 3D Object Detection, the input data uses a dict format (points, img, etc.). See the 3D Object Detection tab in Tracing + Evaluation for details.
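For reference, the `(input_data, _, ids)` triples that the loops in this guide unpack can come from a Dataset like the one below. This is a minimal standalone sketch, not part of the SDK: `PoolDataset` and the dummy `-1` label slot are illustrative assumptions, present only to keep the `(input, label, id)` tuple shape that the loops expect.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PoolDataset(Dataset):
    """Unlabeled pool data: yields (input, dummy_label, id) triples."""
    def __init__(self, images, ids):
        self.images = images  # e.g. a tensor of shape (N, C, H, W)
        self.ids = ids        # identifiers later passed as input_ids=

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        # No ground truth is needed: the label slot is a placeholder (-1)
        return self.images[i], -1, self.ids[i]

pool_dataloader = DataLoader(
    PoolDataset(torch.randn(8, 3, 32, 32), list(range(8))),
    batch_size=4,
)
for input_data, _, ids in pool_dataloader:
    pass  # feed input_data / ids to data_filter as in the loop above
```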

Real-time Filtering Method

When filter_config is specified, you can use filter_flags to branch in real time based on whether each data point matches the filter condition.

import torch

new_data_indices = []  # Indices of data that matched the filter condition

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)

    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,
            input_ids=ids,
        )

    # Branch processing using filter_flags
    for data_id, flag in zip(ids, filter_flags):
        if flag is True:
            # Matches filter condition → route to add to dataset
            new_data_indices.append(data_id)

print(f"{len(new_data_indices)} data points matched the filter condition")

filter_flags Values

  • True: Data matches the filter condition (target for extraction)
  • False: Data does not match the filter condition
  • None: filter_config is not set
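The three cases above can be handled explicitly in the inference loop. A standalone sketch of the branching follows; the `ids` and `filter_flags` lists are hand-made here for illustration, standing in for what `data_filter(...)` would return.

```python
# Hand-made values standing in for one batch of data_filter(...) output
ids = ["img_001", "img_002", "img_003"]
filter_flags = [True, False, None]

matched, skipped = [], []
for data_id, flag in zip(ids, filter_flags):
    if flag is True:      # matches the filter condition → extract
        matched.append(data_id)
    elif flag is False:   # evaluated, but did not match
        skipped.append(data_id)
    else:                 # None: filter_config was not set
        pass
```

Comparing with `is True` / `is False` (rather than truthiness) keeps the `None` case distinct, which is why the loops in this guide use `if flag is True`.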

STEP 3: Query Data (Batch Post-Processing Method)

Select data from collected data based on the specified strategy.

Sort Strategy

Sorts data by the specified metric and retrieves the top N items.

queried_ids = data_filter.query(
    n_data=50,                      # Number of data to retrieve
    strategy="high_error_proba",    # Highest error probability first
)
Strategy             Description
"high_error_proba"   Prioritize data with high error probability (Hard Example Mining)
"low_error_proba"    Prioritize data with low error probability

Filter Strategy

Retrieves up to N items that match the specified conditions. Conditions are specified in dictionary format.

# Zone specification: Get up to 50 data points from Hotspot zones
queried_ids = data_filter.query(
    n_data=50,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# Threshold specification: Get up to 50 data points with error probability >= 0.8
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)

# Combined conditions: Multiple threshold conditions (AND)
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.5", "score": "<0.7"}]},
)

Choosing Between Zones and Threshold Conditions

  • Zone specification: Intuitive selection based on error patterns
  • Threshold specification: Use when you want custom selection with fine-grained conditions
  • When both are specified, they are evaluated as OR conditions
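To illustrate the OR semantics when both are specified, here is a plain-Python sketch (not SDK code) over a hypothetical per-image results table; the record fields and values are invented for illustration.

```python
# Hypothetical per-image evaluation records (illustrative values only)
records = [
    {"id": 1, "zone": "critical_hotspot", "error_proba": 0.4},
    {"id": 2, "zone": "stable_coverage",  "error_proba": 0.9},
    {"id": 3, "zone": "stable_coverage",  "error_proba": 0.2},
]

target_zones = ["critical_hotspot"]
threshold = 0.8  # stands in for the condition {"error_proba": ">=0.8"}

# Zone condition OR threshold condition: a record matching either is kept
selected = [
    r["id"] for r in records
    if r["zone"] in target_zones or r["error_proba"] >= threshold
]
# id 1 matches the zone, id 2 the threshold; id 3 matches neither
```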

Available Zones

Zone Name             Issue Category     Description
stable_coverage       Highly Stable      High-quality output with stable internal features
operational_coverage  Stable             Operationally acceptable with stable internal features
hotspot               Unstable           Internal features are unstable; predictions fluctuate
recessive_hotspot     Under-Confidence   Low model confidence; errors are predictable
critical_hotspot      Over-Confidence    High error probability but high-confidence predictions
aleatoric_hotspot     Outlier            Insufficient features; the model has difficulty learning

For Object Detection (2D / 3D), the strategy parameter accepts a sort strategy or a BBoxStrategy dict. You can customize how per-bbox error probabilities are aggregated to the image (frame) level.

Sort Strategy

queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",    # Highest error probability first
)
Strategy             Description
"high_error_proba"   Prioritize data with high error probability
"low_error_proba"    Prioritize data with low error probability

BBoxStrategy Dict

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",  # Error column to aggregate
        "top_n": 3,                          # Select top N bboxes
        "aggregation": "mean",               # Aggregation method (mean / median)
    },
)
Key            Type            Default         Description
target_column  str             "error_proba"   Error probability column to aggregate
top_n          Optional[int]   None            Select the top N bboxes (None for all)
aggregation    str             "mean"          Aggregation method ("mean" / "median")
weight_column  Optional[str]   None            Column for weighted aggregation
bbox_filter    Optional[dict]  None            Per-bbox filter condition
sort           str             "desc"          Sort order ("asc" / "desc")
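To make the aggregation concrete, here is a standalone sketch (not SDK internals) of how `top_n` plus `aggregation="mean"` would roll per-bbox error probabilities up to one score per image. The data and the `image_score` helper are invented for illustration.

```python
from statistics import mean

# Hypothetical per-bbox error probabilities, grouped by image
bbox_error_proba = {
    "img_a": [0.9, 0.8, 0.2, 0.1],
    "img_b": [0.5, 0.4, 0.3],
}

def image_score(probas, top_n=3, aggregation="mean", sort="desc"):
    # Sort per-bbox values, keep the top_n, then aggregate
    ordered = sorted(probas, reverse=(sort == "desc"))
    kept = ordered if top_n is None else ordered[:top_n]
    return mean(kept)  # "median" would use statistics.median instead

scores = {img: image_score(p) for img, p in bbox_error_proba.items()}
# img_a scores mean([0.9, 0.8, 0.2]); img_b scores mean([0.5, 0.4, 0.3])
```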

String Strategy Compatibility

Sort strategies like "high_error_proba" / "low_error_proba" work the same for both Classification and Object Detection.

Images with No BBox Detections

As of v0.3.1, images where no bounding boxes are detected (due to NMS settings) are excluded from query results. If many such images exist, the number of returned data points may be less than the specified n_data.

Optional Parameters

queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",
    dataset_type="pool",            # dataset_type to filter
    type_cast=int,                  # Type conversion for input_id
)

Complete Sample Code

Classification

import os
import torch
from ml_debugger.data_filter import ClassificationDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # Trained model
pool_dataloader = ...  # DataLoader for unlabeled pool data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=100,
    strategy="high_error_proba",
    type_cast=int,
)

# --- Real-time Filtering Method ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"target_zones": ["critical_hotspot"]},
)

critical_data_indices = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            critical_data_indices.append(idx.item())

Object Detection

import os
import torch
from ml_debugger.data_filter import ObjectDetectionDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # Object Detection model
pool_dataloader = ...  # DataLoader for unlabeled pool data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- Real-time Filtering Method ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_images = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            high_error_images.append(idx.item())

3D Object Detection

import os
import torch
from ml_debugger.data_filter import ObjectDetection3DDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # 3D Object Detection model
pool_dataloader = ...  # DataLoader for unlabeled pool data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,
)

for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(points, input_ids=frame_ids)

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- Real-time Filtering Method ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_frames = []
for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(points, input_ids=frame_ids)
    for fid, flag in zip(frame_ids, filter_flags):
        if flag is True:
            high_error_frames.append(fid)

Active Learning Workflow

DataFiltering is designed to be used within an Active Learning loop.

# Active Learning loop
for iteration in range(n_iterations):
    # 1. Train with current training data
    train_model(model, train_dataloader)

    # 2. Run evaluation
    tracer = ClassificationTracer(model, model_name, version_name)
    # ... collect data ...
    result = evaluator.request_evaluation()

    # 3. Query from pool data
    data_filter = ClassificationDataFilter(
        model, model_name, version_name, result.result_name
    )
    # ... pool data inference ...
    queried_ids = data_filter.query(n_data=100, strategy="high_error_proba")

    # 4. Add selected data to training data
    train_indices += queried_ids
    pool_indices = list(set(pool_indices) - set(queried_ids))

See Active Learning Use Case for detailed implementation examples.
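Step 4 of the loop can be wired into PyTorch DataLoaders with torch.utils.data.Subset. A minimal sketch follows, assuming an index-addressable dataset; the toy tensors and the specific `queried_ids` values are invented stand-ins.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy dataset standing in for the full labeled + pool corpus
full_dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

train_indices = list(range(20))
pool_indices = list(range(20, 100))

# Pretend these ids came back from data_filter.query(...)
queried_ids = [25, 30, 42]

# 4. Move the queried data from the pool into the training split
train_indices += queried_ids
pool_indices = [i for i in pool_indices if i not in set(queried_ids)]

# Rebuild the loaders for the next iteration of the loop
train_dataloader = DataLoader(Subset(full_dataset, train_indices),
                              batch_size=8, shuffle=True)
pool_dataloader = DataLoader(Subset(full_dataset, pool_indices), batch_size=8)
```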

Specifying n_epoch

When multiple training runs have been performed with the same version_name, explicitly specifying n_epoch helps avoid data conflicts.

Specifying n_epoch at DataFilter Initialization

data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    n_epoch=5,  # Use data from a specific epoch
)

Why Specifying n_epoch is Important

Because n_epoch="latest" resolves the latest epoch by timestamp, restarting epoch numbering under the same version_name can cause unintended data to be used.

# Problematic case
# If you trained with epoch 0-9 in the past, then retrained with epoch 0-4
# latest points to epoch=4 (not the past epoch=9)

# Solution 1: Explicitly specify n_epoch
data_filter = ClassificationDataFilter(..., n_epoch=9)

# Solution 2: Use a new version_name
data_filter = ClassificationDataFilter(..., version_name="v2")

Recommended Solutions

  • Use a new version_name when starting new training
  • Explicitly specify n_epoch when you want to use data from a specific epoch
  • Specify result_name to ensure data associated with the evaluation result is used

See model_name / version_name for details.

Next Steps

  • Logging - Inference log collection during operation