Skip to content

Active Learning

This guide explains how to implement Active Learning using MLdebugger's DataFilter.

Prerequisites

For basic DataFilter operations, see Getting Started - DataFiltering. For the concept of data selection based on Issue Category, see Data Curation.

Overview

Active Learning is a technique that improves model performance with lower labeling costs by selectively labeling the most informative data from an unlabeled data pool.

MLdebugger's ClassificationDataFilter / ObjectDetectionDataFilter / ObjectDetection3DDataFilter achieve effective data selection based on analysis of internal features and error probability.

Active Learning Workflow

Experiment Setup

The following is an example. Adjust according to your task, dataset, and model.

# Experiment parameters (example values -- tune per task, dataset, and model)
n_data_base = 3000      # Initial training data count (size of the seed labeled set)
n_epochs_base = 5       # Initial training epochs (before the first query round)
n_iters = 40            # Active Learning iterations (query/retrain rounds)
n_query = 100           # Data to add per iteration (labeling budget per round)
n_epochs = 5            # Training epochs per iteration

Complete Workflow

import torch
from ml_debugger.training import ClassificationTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ClassificationDataFilter

model = Net()
model_name = "my_model"
version_name = "active_learning_v1"

# Labeled training set vs. unlabeled pool, tracked by dataset indices.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ClassificationTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for images, labels, indices in train_dataloader:
            # Move both tensors to the same device so the tracer and the
            # loss computation below agree (criterion needs labels on device).
            images = images.to(device)
            labels = labels.to(device)
            outputs = tracer(
                images,
                labels,
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = criterion(outputs, labels)
            # Clear gradients from the previous batch; without this,
            # loss.backward() accumulates gradients across batches.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Block until traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ClassificationDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    # Inference-only pass over the unlabeled pool: eval mode + no_grad
    # to disable dropout/batch-norm updates and gradient tracking.
    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])

    # Hard Example Mining: pick the n_query pool samples most likely mispredicted.
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetectionTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetectionDataFilter

model = ...  # Object Detection model
model_name = "my_od_model"
version_name = "active_learning_v1"

# Labeled training set vs. unlabeled pool, tracked by dataset indices.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetectionTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for images, targets, indices in train_dataloader:
            outputs = tracer(
                images.to(device),
                targets,
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            # Clear gradients from the previous batch; without this,
            # loss.backward() accumulates gradients across batches.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Block until traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetectionDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    # Inference-only pass over the unlabeled pool: eval mode + no_grad
    # to disable dropout/batch-norm updates and gradient tracking.
    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])

    # Hard Example Mining: pick the n_query pool samples most likely mispredicted.
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetection3DTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetection3DDataFilter

model = ...  # 3D Object Detection model
model_name = "my_3d_od_model"
version_name = "active_learning_v1"

# Labeled training set vs. unlabeled pool, tracked by dataset indices.
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetection3DTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for points, targets, frame_ids in train_dataloader:
            outputs = tracer(
                points.to(device),
                targets,
                input_ids=frame_ids,
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            # Clear gradients from the previous batch; without this,
            # loss.backward() accumulates gradients across batches.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Block until traces for the final epoch are persisted.
    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetection3DDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    # Inference-only pass over the unlabeled pool: eval mode + no_grad
    # (the no_grad context matches the 2D examples and avoids building
    # an unused autograd graph during pool scoring).
    model.eval()
    for points, _, frame_ids in pool_dataloader:
        with torch.no_grad():
            data_filter(points.to(device), input_ids=frame_ids)

    # Hard Example Mining: pick the n_query pool frames most likely mispredicted.
    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))

Data Selection Strategies

Sort Strategy (All Tasks)

Strategy Description Use Case
"high_error_proba" Prioritize data with high error probability Hard Example Mining
"low_error_proba" Prioritize data with low error probability Collecting reliable data

Filter Strategy

Select data by Issue Category zone or threshold conditions.

# Filter strategy: restrict selection to data falling in Hotspot zones
# (Issue Category zones with elevated error likelihood)
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# Filter strategy: custom threshold conditions -- keep only data whose
# error probability is at least 0.8
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)

Use a BBoxStrategy dict to specify how per-bbox error probabilities are aggregated to image-level scores.

# BBoxStrategy: score each image by the mean error probability of its
# top-3 bounding boxes, then query the highest-scoring images
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

Sort strategies ("high_error_proba", etc.) also work the same as Classification.

See Getting Started - DataFiltering for details.

Best Practices

1. Appropriate Evaluation Frequency

Running an evaluation after every epoch is computationally expensive; evaluating only every few epochs is recommended instead.

# Request evaluation only every `eval_interval` epochs to limit compute cost
if epoch % eval_interval == 0:
    result = evaluator.request_evaluation(n_epoch=epoch)

2. Dataset Management

Always remove selected data from the pool so that the same samples cannot be queried again in later iterations.

pool_indices = list(set(pool_indices) - set(queried_indices))

3. Early Stopping

Consider early termination when performance converges.

# Stop iterating once the metric improvement drops below `threshold`
if current_metric - previous_metric < threshold:
    print("Performance converged, stopping early")
    break

4. Saving Logs

Save logs to track experiment results.

# Record per-iteration state (round number, labeled-set size, metrics)
# for later analysis and plotting
logs.append({
    "iteration": iteration,
    "train_size": len(train_indices),
    "metrics": result.metrics_summary(),
})

Real-time Filtering vs Batch Processing

For Active Learning, the batch post-processing method using query() is recommended. For data collection optimization in production environments (Data Curation), real-time filtering using filter_config is more suitable.

Method Use Case Details
Batch Post-Processing (query()) Active Learning This page
Real-time Filtering Data Curation Data Curation

See Getting Started - DataFiltering for details.

Next Steps