コンテンツにスキップ

Active Learning

このガイドでは、MLdebuggerのDataFilterを使用したActive Learningの実装方法を説明します。

前提知識

DataFilterの基本操作は Getting Started - DataFiltering、Issue Categoryに基づくデータ選択の概念は Data Curation を参照してください。

概要

Active Learningは、ラベルなしデータプールから最も有益なデータを選択的にラベル付けすることで、少ないラベル付けコストでモデルの性能を向上させる手法です。

MLdebuggerの ClassificationDataFilter / ObjectDetectionDataFilter / ObjectDetection3DDataFilter は、内部特徴量とエラー確率の分析に基づいて効果的なデータ選択を実現します。

Active Learningワークフロー

実験設定

以下は一例です。使用するタスク、データセット、モデルに応じて調整してください。

# 実験パラメータ
n_data_base = 3000      # 初期トレーニングデータ数
n_epochs_base = 5       # 初期学習エポック数
n_iters = 40            # Active Learningイテレーション数
n_query = 100           # 各イテレーションで追加するデータ数
n_epochs = 5            # 各イテレーションの学習エポック数

完全なワークフロー

import torch
from ml_debugger.training import ClassificationTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ClassificationDataFilter

model = Net()
model_name = "my_model"
version_name = "active_learning_v1"

train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ClassificationTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for images, labels, indices in train_dataloader:
            outputs = tracer(
                images.to(device),
                labels.to(device),
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ClassificationDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])

    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetectionTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetectionDataFilter

model = ...  # Object Detectionモデル
model_name = "my_od_model"
version_name = "active_learning_v1"

train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetectionTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for images, targets, indices in train_dataloader:
            outputs = tracer(
                images.to(device),
                targets,
                input_ids=[f"train_{i}" for i in indices],
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            loss.backward()
            optimizer.step()

    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetectionDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    model.eval()
    for images, _, indices in pool_dataloader:
        with torch.no_grad():
            data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])

    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetection3DTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetection3DDataFilter

model = ...  # 3D Object Detectionモデル
model_name = "my_3d_od_model"
version_name = "active_learning_v1"

train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()

for iteration in range(n_iters):
    # === Step 1: Tracing + Training ===
    tracer = ObjectDetection3DTracer(model, model_name, version_name)

    for epoch in range(n_epochs):
        model.train()
        for points, targets, frame_ids in train_dataloader:
            outputs = tracer(
                points.to(device),
                targets,
                input_ids=frame_ids,
                dataset_type="train",
                n_epoch=epoch,
            )
            loss = compute_loss(outputs, targets)
            loss.backward()
            optimizer.step()

    tracer.wait_for_save(n_epoch=n_epochs - 1)

    # === Step 2: Evaluation ===
    evaluator = Evaluator(model_name, version_name)
    result = evaluator.request_evaluation(n_epoch="latest")

    # === Step 3: Data Selection ===
    data_filter = ObjectDetection3DDataFilter(
        model, model_name, version_name,
        result_name=result.result_name,
    )

    model.eval()
    for points, _, frame_ids in pool_dataloader:
        data_filter(points.to(device), input_ids=frame_ids)

    queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")

    # === Step 4: Update Dataset ===
    queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
    train_indices += queried_indices
    pool_indices = list(set(pool_indices) - set(queried_indices))

データ選択戦略

ソート戦略(全タスク共通)

戦略 説明 用途
"high_error_proba" エラー確率が高いデータを優先選択 Hard Example Mining
"low_error_proba" エラー確率が低いデータを優先選択 信頼性の高いデータ収集

フィルタ戦略

Issue Categoryに基づくゾーン指定や閾値条件でデータを選択します。

# Hotspotゾーンのデータを選択
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# カスタム閾値条件
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)

BBoxStrategy 辞書を使用して、BBox単位のエラー確率を画像単位に集約する条件を指定します。

# 検出エラー確率が高いBBox上位3つの平均で集約
queried_ids = data_filter.query(
    n_data=n_query,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

ソート戦略("high_error_proba" 等)もClassificationと同様に使用できます。

詳細は Getting Started - DataFiltering を参照してください。

ベストプラクティス

1. 適切な評価頻度

毎エポック評価すると計算コストが高いため、数エポックごとに評価することを推奨します。

if epoch % eval_interval == 0:
    result = evaluator.request_evaluation(n_epoch=epoch)

2. データセット管理

選択済みデータの重複を防ぐため、プールからの削除を確実に行います。

pool_indices = list(set(pool_indices) - set(queried_indices))

3. 早期停止

性能が収束したら早期終了を検討します。

if current_metric - previous_metric < threshold:
    print("Performance converged, stopping early")
    break

4. ログの保存

実験結果を追跡するためにログを保存します。

logs.append({
    "iteration": iteration,
    "train_size": len(train_indices),
    "metrics": result.metrics_summary(),
})

リアルタイムフィルタリングとの使い分け

Active Learningでは、query() メソッドによるバッチ後処理方式が推奨されます。 運用環境でのデータ収集最適化(Data Curation)では、filter_config を使用したリアルタイムフィルタリング方式が適しています。

方式 用途 詳細
バッチ後処理(query() Active Learning 本ページ
リアルタイムフィルタリング Data Curation Data Curation

詳細は Getting Started - DataFiltering を参照してください。

次のステップ