# DataFiltering

This guide explains how to perform data filtering based on error patterns using the MLdebugger SDK.
## Overview

The DataFiltering flow is a workflow for efficiently selecting data from pool data based on evaluation results. Two methods are supported:

- **Batch Post-Processing Method**: Select data in bulk using the `query()` method after running inference on all data
- **Real-time Filtering Method**: Evaluate filter conditions in real-time during inference
The DataFilter class used in STEP 1 varies by task:
| Task | DataFilter Class |
|---|---|
| Classification | ClassificationDataFilter |
| Object Detection | ObjectDetectionDataFilter |
| 3D Object Detection | ObjectDetection3DDataFilter |
> **Prerequisite:** Tracing + Evaluation must be complete and `result_name` must be obtained.
## STEP 1: Initialize DataFilter

### Batch Post-Processing Method (Basic)

```python
from ml_debugger.data_filter import ClassificationDataFilter

data_filter = ClassificationDataFilter(
    model,                  # Model to evaluate
    model_name="resnet18",  # model_name used with Tracer
    version_name="v1",      # version_name used with Tracer
    result_name="resnet18_v1_classification_v1_20251219",  # result_name of evaluation
)
```
### Real-time Filtering Method

If you want to filter data in real-time during inference, use the `filter_config` parameter.

```python
data_filter = ClassificationDataFilter(
    model,
    model_name="resnet18",
    version_name="v1",
    result_name="resnet18_v1_classification_v1_20251219",
    filter_config={"target_zones": ["critical_hotspot"]},  # Target the Critical Hotspot zone
)
```
#### Selecting target_zones

- `critical_hotspot`: Data where the model makes incorrect predictions with high confidence. Most important for quality improvement.
- `hotspot`: General error zone. Use when you want to collect a wide range of error patterns.
- `stable_coverage`: Data where the model makes correct predictions with high confidence. Can be used to thin out data with low learning contribution.
- Multiple zones can be specified: `{"target_zones": ["critical_hotspot", "hotspot"]}` (OR condition)

Details and conditions for available zones can be viewed on the Heatmap screen in the GUI.
### Batch Post-Processing Method (Basic)

```python
from ml_debugger.data_filter import ObjectDetectionDataFilter

data_filter = ObjectDetectionDataFilter(
    model,                     # Object Detection model
    model_name="faster_rcnn",  # model_name used with Tracer
    version_name="v1",         # version_name used with Tracer
    result_name="faster_rcnn_v1_od_v1_20251219",  # result_name of evaluation
)
```
### Real-time Filtering Method

For Object Detection, pass a BBoxStrategy dict to `filter_config`. See the Object Detection tab in STEP 3 for details.

```python
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="faster_rcnn",
    version_name="v1",
    result_name="faster_rcnn_v1_od_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
```
### Batch Post-Processing Method (Basic)

```python
from ml_debugger.data_filter import ObjectDetection3DDataFilter

data_filter = ObjectDetection3DDataFilter(
    model,                     # 3D Object Detection model
    model_name="centerpoint",  # model_name used with Tracer
    version_name="v1",         # version_name used with Tracer
    result_name="centerpoint_v1_od3d_v1_20251219",  # result_name of evaluation
)
```
### Real-time Filtering Method

Like 2D Object Detection, pass a BBoxStrategy dict to `filter_config`. See the Object Detection tab in STEP 3 for details.

```python
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="centerpoint",
    version_name="v1",
    result_name="centerpoint_v1_od3d_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
```
### Getting result_name

`result_name` can be obtained from the Result object returned by `Evaluator.request_evaluation()`.

```python
result = evaluator.request_evaluation()
result_name = result.result_name
```
## STEP 2: Inference on Pool Data

### Batch Post-Processing Method

Run inference on unlabeled pool data and collect internal features.

```python
import torch

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,     # Input data
            input_ids=ids,  # Data identifiers
        )
    # When filter_config is not set, filter_flags are all None
```
#### Labels Not Required

DataFilter does not require ground truth labels. It only collects inference results and internal features from pool data.

#### 3D Object Detection Input Format

For 3D Object Detection, the input data uses a dict format (`points`, `img`, etc.). See the 3D Object Detection tab in Tracing + Evaluation for details.
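As a rough sketch only (the key names `points` and `img` follow the mention above; the exact keys your model's `forward()` expects may differ), a dict-format batch could look like:

```python
# Hypothetical dict-format input batch for 3D Object Detection.
# The keys shown ("points", "img") follow the names mentioned above;
# check Tracing + Evaluation for the format your model actually requires.
batch = {
    "points": [[0.1, 0.2, 0.3, 0.5]],  # e.g. LiDAR points (x, y, z, intensity)
    "img": None,                       # optional camera input
}

# The dict is passed to the DataFilter in place of a plain tensor, e.g.:
# model_output, filter_flags = data_filter(batch, input_ids=frame_ids)
```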
### Real-time Filtering Method

When `filter_config` is specified, you can use `filter_flags` to perform real-time branching based on whether each data point matches the filter condition.

```python
import torch

new_data_indices = []  # Indices of data that matched the filter condition

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,
            input_ids=ids,
        )
    # Branch processing using filter_flags
    for id, flag in zip(ids, filter_flags):
        if flag is True:
            # Matches filter condition → route to add to dataset
            new_data_indices.append(id)

print(f"Total {len(new_data_indices)} data points matched filter condition")
```
#### filter_flags Values

- `True`: Data matches the filter condition (target for extraction)
- `False`: Data does not match the filter condition
- `None`: `filter_config` is not set
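The three values can be handled with a simple partition; a plain-Python illustration (hypothetical ids and flags, not actual SDK output):

```python
# Hypothetical ids and filter_flags, as returned per batch by data_filter()
ids = ["img_001", "img_002", "img_003", "img_004"]
filter_flags = [True, False, True, None]

matched = [i for i, f in zip(ids, filter_flags) if f is True]    # extraction targets
unmatched = [i for i, f in zip(ids, filter_flags) if f is False]
unset = [i for i, f in zip(ids, filter_flags) if f is None]      # filter_config not set
```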
## STEP 3: Query Data (Batch Post-Processing Method)

Select data from the collected data based on the specified strategy.

### Sort Strategy

Sorts data by the specified metric and retrieves the top N items.

```python
queried_ids = data_filter.query(
    n_data=50,                    # Number of data to retrieve
    strategy="high_error_proba",  # Highest error probability first
)
```
| Strategy | Description |
|---|---|
| `"high_error_proba"` | Prioritize data with high error probability (Hard Example Mining) |
| `"low_error_proba"` | Prioritize data with low error probability |
### Filter Strategy

Retrieves up to N items that match the specified conditions. Conditions are specified in dictionary format.

```python
# Zone specification: get up to 50 data points from Hotspot zones
queried_ids = data_filter.query(
    n_data=50,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# Threshold specification: get up to 50 data points with error probability >= 0.8
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)

# Combined conditions: multiple threshold conditions (AND)
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.5", "score": "<0.7"}]},
)
```
#### Choosing Between Zones and Threshold Conditions

- Zone specification: intuitive selection based on error patterns
- Threshold specification: use when you want custom selection with fine-grained conditions
- When both are specified, they are evaluated as OR conditions
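The OR behavior can be illustrated in plain Python (toy rows stand in for the SDK's internal evaluation records; this is a sketch of the semantics, not the actual implementation):

```python
# Toy records: each row has a zone label and an error probability
rows = [
    {"id": 1, "zone": "hotspot", "error_proba": 0.3},
    {"id": 2, "zone": "stable_coverage", "error_proba": 0.9},
    {"id": 3, "zone": "stable_coverage", "error_proba": 0.1},
]
target_zones = ["hotspot", "critical_hotspot"]

# When both zones and a threshold are specified,
# a row is selected if EITHER condition holds (OR)
selected = [
    r["id"]
    for r in rows
    if r["zone"] in target_zones or r["error_proba"] >= 0.8
]
# Row 1 matches by zone, row 2 by threshold, row 3 by neither
```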
#### Available Zones

| Zone Name | Issue Category | Description |
|---|---|---|
| `stable_coverage` | Highly Stable | High-quality output with stable internal features |
| `operational_coverage` | Stable | Operationally acceptable with stable internal features |
| `hotspot` | Unstable | Internal features are unstable, predictions fluctuate |
| `recessive_hotspot` | Under-Confidence | Low model confidence, errors are predictable |
| `critical_hotspot` | Over-Confidence | High error probability but high-confidence predictions |
| `aleatoric_hotspot` | Outlier | Insufficient features, model has difficulty learning |
For Object Detection (2D / 3D), the strategy parameter accepts a sort strategy or a BBoxStrategy dict. You can customize how per-bbox error probabilities are aggregated to the image (frame) level.

### Sort Strategy

```python
queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",  # Highest error probability first
)
```

| Strategy | Description |
|---|---|
| `"high_error_proba"` | Prioritize data with high error probability |
| `"low_error_proba"` | Prioritize data with low error probability |
### BBoxStrategy Dict

```python
queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",  # Error column to aggregate
        "top_n": 3,                      # Select top N bboxes
        "aggregation": "mean",           # Aggregation method (mean / median)
    },
)
```

| Key | Type | Default | Description |
|---|---|---|---|
| `target_column` | `str` | `"error_proba"` | Error probability column to aggregate |
| `top_n` | `Optional[int]` | `None` | Select top N bboxes (None for all) |
| `aggregation` | `str` | `"mean"` | Aggregation method (`"mean"` / `"median"`) |
| `weight_column` | `Optional[str]` | `None` | Column for weighted aggregation |
| `bbox_filter` | `Optional[dict]` | `None` | Per-bbox filter condition |
| `sort` | `str` | `"desc"` | Sort order (`"asc"` / `"desc"`) |
#### String Strategy Compatibility

Sort strategies like `"high_error_proba"` / `"low_error_proba"` work the same for both Classification and Object Detection.

#### Images with No BBox Detections

As of v0.3.1, images where no bounding boxes are detected (due to NMS settings) are excluded from query results. If many such images exist, the number of returned data points may be less than the specified `n_data`.
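A plain-Python illustration of this exclusion (hypothetical per-image bbox counts, not SDK output):

```python
# Hypothetical detection counts per image after NMS
bbox_counts = {"img_1": 4, "img_2": 0, "img_3": 2, "img_4": 0}

# Images with zero detected bboxes are excluded from query results
candidates = [img for img, n in bbox_counts.items() if n > 0]

n_data = 3
queried_ids = candidates[:n_data]
# Only 2 ids come back even though n_data=3 was requested
```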
### Optional Parameters

```python
queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",
    dataset_type="pool",  # dataset_type to filter
    type_cast=int,        # Type conversion for input_id
)
```
## Complete Sample Code

```python
import os

import torch

from ml_debugger.data_filter import ClassificationDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...            # Trained model
pool_dataloader = ...  # DataLoader for unlabeled pool data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,  # result from evaluator.request_evaluation()
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=100,
    strategy="high_error_proba",
    type_cast=int,
)

# --- Real-time Filtering Method ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"target_zones": ["critical_hotspot"]},
)

critical_data_indices = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            critical_data_indices.append(idx.item())
```
```python
import os

import torch

from ml_debugger.data_filter import ObjectDetectionDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...            # Object Detection model
pool_dataloader = ...  # DataLoader for unlabeled pool data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,  # result from evaluator.request_evaluation()
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- Real-time Filtering Method ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_images = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            high_error_images.append(idx.item())
```
```python
import os

import torch

from ml_debugger.data_filter import ObjectDetection3DDataFilter

# Set authentication credentials
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...            # 3D Object Detection model
pool_dataloader = ...  # DataLoader for unlabeled pool data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# --- Batch Post-Processing Method ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,  # result from evaluator.request_evaluation()
)

for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    model_output, filter_flags = data_filter(points, input_ids=frame_ids)

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- Real-time Filtering Method ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_frames = []
for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    model_output, filter_flags = data_filter(points, input_ids=frame_ids)
    for fid, flag in zip(frame_ids, filter_flags):
        if flag is True:
            high_error_frames.append(fid)
```
## Active Learning Workflow

DataFiltering is designed to be used within an Active Learning loop.

```python
# Active Learning loop
for iteration in range(n_iterations):
    # 1. Train with current training data
    train_model(model, train_dataloader)

    # 2. Run evaluation
    tracer = ClassificationTracer(model, model_name, version_name)
    # ... collect data ...
    result = evaluator.request_evaluation()

    # 3. Query from pool data
    data_filter = ClassificationDataFilter(
        model, model_name, version_name, result.result_name
    )
    # ... pool data inference ...
    queried_ids = data_filter.query(n_data=100, strategy="high_error_proba")

    # 4. Add selected data to training data
    train_indices += queried_ids
    pool_indices = list(set(pool_indices) - set(queried_ids))
```

See Active Learning Use Case for detailed implementation examples.
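The index bookkeeping in step 4 can be checked in isolation; a minimal plain-Python sketch with hypothetical indices:

```python
# Hypothetical indices before one Active Learning iteration
train_indices = [0, 1, 2]
pool_indices = [3, 4, 5, 6]
queried_ids = [4, 6]  # e.g. returned by data_filter.query()

# Move the queried data from the pool into the training set
train_indices += queried_ids
pool_indices = sorted(set(pool_indices) - set(queried_ids))
```

Sorting the remaining pool indices keeps the order deterministic across iterations; the set difference guarantees no data point appears in both sets.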
## Specifying n_epoch

When multiple training runs have been performed with the same `version_name`, explicitly specifying `n_epoch` is effective for avoiding data conflicts.

### Specifying n_epoch at DataFilter Initialization

```python
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    n_epoch=5,  # Use data from a specific epoch
)
```
### Why Specifying n_epoch is Important

Since `n_epoch="latest"` determines the latest epoch based on timestamp, if you restart epochs from 1 with the same `version_name`, unintended data may be used.

```python
# Problematic case:
# if you trained with epochs 0-9 in the past, then retrained with epochs 0-4,
# "latest" points to epoch=4 (not the past epoch=9)

# Solution 1: explicitly specify n_epoch
data_filter = ClassificationDataFilter(..., n_epoch=9)

# Solution 2: use a new version_name
data_filter = ClassificationDataFilter(..., version_name="v2")
```
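The timestamp-based pitfall can be reproduced with a toy model of the stored records (hypothetical structure; the SDK's actual storage format may differ):

```python
# Toy records: one entry per traced run, resolved by newest timestamp
records = [
    {"n_epoch": 9, "timestamp": "2025-11-01T10:00:00"},  # old run (epochs 0-9)
    {"n_epoch": 4, "timestamp": "2025-12-19T10:00:00"},  # retrain (epochs 0-4)
]

# n_epoch="latest" resolves by timestamp, so the retrain's epoch 4 wins,
# even though epoch 9 exists from the earlier run
latest = max(records, key=lambda r: r["timestamp"])
```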
### Recommended Solutions

- Use a new `version_name` when starting new training
- Explicitly specify `n_epoch` when you want to use data from a specific epoch
- Specify `result_name` to ensure data associated with the evaluation result is used

See model_name / version_name for details.
## Next Steps

- Logging - Inference log collection during operation