Active Learning
This guide explains how to implement Active Learning using MLdebugger's DataFilter.
Prerequisites
For basic DataFilter operations, see Getting Started - DataFiltering. For the concept of data selection based on Issue Category, see Data Curation.
Overview
Active Learning is a technique that improves model performance with lower labeling costs by selectively labeling the most informative data from an unlabeled data pool.
MLdebugger's ClassificationDataFilter / ObjectDetectionDataFilter / ObjectDetection3DDataFilter achieve effective data selection based on analysis of internal features and error probability.
Active Learning Workflow
Experiment Setup
The following is an example. Adjust according to your task, dataset, and model.
# Experiment parameters -- adjust for your own task, dataset, and model.
n_data_base = 3000   # size of the initial labeled training set
n_epochs_base = 5    # epochs for the initial training run
n_iters = 40         # number of Active Learning iterations
n_query = 100        # newly labeled samples added per iteration
n_epochs = 5         # training epochs within each iteration
Complete Workflow
import torch

from ml_debugger.training import ClassificationTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ClassificationDataFilter

model = Net()
model_name = "my_model"
version_name = "active_learning_v1"

# Copy so the original index lists remain intact for later comparison/re-runs.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ClassificationTracer(model, model_name, version_name)
    for epoch in range(n_epochs):
        model.train()
        for images, labels, indices in train_dataloader:
            # Move tensors once; criterion needs labels on the same
            # device as the model outputs.
            images = images.to(device)
            labels = labels.to(device)
            outputs = tracer(
                images,
                labels,
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = criterion(outputs, labels)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
    # Block until the traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ClassificationDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )
    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():  # inference only -- no gradients needed
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    # Map "pool_<i>" ids back to integer indices (avoid shadowing builtin `id`).
    queried_indices = [int(qid.replace("pool_", "")) for qid in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch

from ml_debugger.training import ObjectDetectionTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetectionDataFilter

model = ...  # Object Detection model
model_name = "my_od_model"
version_name = "active_learning_v1"

# Copy so the original index lists remain intact for later comparison/re-runs.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetectionTracer(model, model_name, version_name)
    for epoch in range(n_epochs):
        model.train()
        for images, targets, indices in train_dataloader:
            outputs = tracer(
                images.to(device),
                targets,
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
    # Block until the traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetectionDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )
    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():  # inference only -- no gradients needed
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    # Map "pool_<i>" ids back to integer indices (avoid shadowing builtin `id`).
    queried_indices = [int(qid.replace("pool_", "")) for qid in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch

from ml_debugger.training import ObjectDetection3DTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetection3DDataFilter

model = ...  # 3D Object Detection model
model_name = "my_3d_od_model"
version_name = "active_learning_v1"

# Copy so the original index lists remain intact for later comparison/re-runs.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetection3DTracer(model, model_name, version_name)
    for epoch in range(n_epochs):
        model.train()
        for points, targets, frame_ids in train_dataloader:
            outputs = tracer(
                points.to(device),
                targets,
                input_ids=frame_ids,
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
    # Block until the traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetection3DDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )
    model.eval()
    for points, _, frame_ids in pool_dataloader:
        with torch.no_grad():  # inference only -- consistent with the 2D examples
            data_filter(points.to(device), input_ids=frame_ids)
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    # Map "pool_<i>" ids back to integer indices (avoid shadowing builtin `id`).
    queried_indices = [int(qid.replace("pool_", "")) for qid in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
Data Selection Strategies
Sort Strategy (All Tasks)
| Strategy | Description | Use Case |
|---|---|---|
| `"high_error_proba"` | Prioritize data with high error probability | Hard Example Mining |
| `"low_error_proba"` | Prioritize data with low error probability | Collecting reliable data |
Filter Strategy
Select data by Issue Category zone or threshold conditions.
# Select data from Hotspot zones
# Filter strategy: restrict selection to the listed Issue Category zones.
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# Custom threshold conditions
# Keep only samples whose error probability is at least 0.8; the
# comparison is expressed as a string, per the DataFilter query API.
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)
Use a BBoxStrategy dict to specify how per-bbox error probabilities are aggregated to image-level scores.
# Aggregate top-3 bboxes by detection error probability
# BBoxStrategy dict: presumably takes the mean of the 3 highest per-bbox
# "error_proba" values to produce one image-level score per sample --
# confirm against the DataFilter reference.
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)
Sort strategies ("high_error_proba", etc.) also work the same as Classification.
See Getting Started - DataFiltering for details.
Best Practices
1. Appropriate Evaluation Frequency
Evaluating every epoch is computationally expensive, so evaluation every few epochs is recommended.
# Evaluate only every `eval_interval` epochs to keep compute costs down.
if epoch % eval_interval == 0:
    result = evaluator.request_evaluation(n_epoch=epoch)
2. Dataset Management
Always remove selected data from the pool so the same samples are not queried again in later iterations.
pool_indices = list(set(pool_indices) - set(queried_indices))
3. Early Stopping
Consider early termination when performance converges.
if current_metric - previous_metric < threshold:
print("Performance converged, stopping early")
break
4. Saving Logs
Save logs to track experiment results.
# Record per-iteration state so the experiment can be analyzed afterwards.
logs.append({
    "iteration": iteration,
    "train_size": len(train_indices),
    "metrics": result.metrics_summary(),
})
Real-time Filtering vs Batch Processing
For Active Learning, the batch post-processing method using query() is recommended.
For data collection optimization in production environments (Data Curation), real-time filtering using filter_config is more suitable.
| Method | Use Case | Details |
|---|---|---|
| Batch Post-Processing (`query()`) | Active Learning | This page |
| Real-time Filtering | Data Curation | Data Curation |
See Getting Started - DataFiltering for details.
Next Steps
- Data Curation - Data curation concepts
- Getting Started - DataFiltering - Basic DataFilter operations