Active Learning
このガイドでは、MLdebuggerのDataFilterを使用したActive Learningの実装方法を説明します。
前提知識
DataFilterの基本操作は Getting Started - DataFiltering、Issue Categoryに基づくデータ選択の概念は Data Curation を参照してください。
概要
Active Learningは、ラベルなしデータプールから最も有益なデータを選択的にラベル付けすることで、少ないラベル付けコストでモデルの性能を向上させる手法です。
MLdebuggerの ClassificationDataFilter / ObjectDetectionDataFilter / ObjectDetection3DDataFilter は、内部特徴量とエラー確率の分析に基づいて効果的なデータ選択を実現します。
Active Learningワークフロー
実験設定
以下は一例です。使用するタスク、データセット、モデルに応じて調整してください。
# 実験パラメータ
n_data_base = 3000 # 初期トレーニングデータ数
n_epochs_base = 5 # 初期学習エポック数
n_iters = 40 # Active Learningイテレーション数
n_query = 100 # 各イテレーションで追加するデータ数
n_epochs = 5 # 各イテレーションの学習エポック数
完全なワークフロー
import torch
from ml_debugger.training import ClassificationTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ClassificationDataFilter
model = Net()
model_name = "my_model"
version_name = "active_learning_v1"
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()
for iteration in range(n_iters):
# === Step 1: Tracing + Training ===
tracer = ClassificationTracer(model, model_name, version_name)
for epoch in range(n_epochs):
model.train()
for images, labels, indices in train_dataloader:
outputs = tracer(
images.to(device),
labels.to(device),
input_ids=[f"train_{i}" for i in indices],
dataset_type="train",
n_epoch=epoch,
)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
tracer.wait_for_save(n_epoch=n_epochs - 1)
# === Step 2: Evaluation ===
evaluator = Evaluator(model_name, version_name)
result = evaluator.request_evaluation(n_epoch="latest")
# === Step 3: Data Selection ===
data_filter = ClassificationDataFilter(
model, model_name, version_name,
result_name=result.result_name,
)
model.eval()
for images, _, indices in pool_dataloader:
with torch.no_grad():
data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])
queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")
# === Step 4: Update Dataset ===
queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
train_indices += queried_indices
pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetectionTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetectionDataFilter
model = ... # Object Detectionモデル
model_name = "my_od_model"
version_name = "active_learning_v1"
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()
for iteration in range(n_iters):
# === Step 1: Tracing + Training ===
tracer = ObjectDetectionTracer(model, model_name, version_name)
for epoch in range(n_epochs):
model.train()
for images, targets, indices in train_dataloader:
outputs = tracer(
images.to(device),
targets,
input_ids=[f"train_{i}" for i in indices],
dataset_type="train",
n_epoch=epoch,
)
loss = compute_loss(outputs, targets)
loss.backward()
optimizer.step()
tracer.wait_for_save(n_epoch=n_epochs - 1)
# === Step 2: Evaluation ===
evaluator = Evaluator(model_name, version_name)
result = evaluator.request_evaluation(n_epoch="latest")
# === Step 3: Data Selection ===
data_filter = ObjectDetectionDataFilter(
model, model_name, version_name,
result_name=result.result_name,
)
model.eval()
for images, _, indices in pool_dataloader:
with torch.no_grad():
data_filter(images.to(device), input_ids=[f"pool_{i}" for i in indices])
queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")
# === Step 4: Update Dataset ===
queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
train_indices += queried_indices
pool_indices = list(set(pool_indices) - set(queried_indices))
import torch
from ml_debugger.training import ObjectDetection3DTracer
from ml_debugger.evaluator import Evaluator
from ml_debugger.data_filter import ObjectDetection3DDataFilter
model = ... # 3D Object Detectionモデル
model_name = "my_3d_od_model"
version_name = "active_learning_v1"
train_indices = initial_train_indices.copy()
pool_indices = initial_pool_indices.copy()
for iteration in range(n_iters):
# === Step 1: Tracing + Training ===
tracer = ObjectDetection3DTracer(model, model_name, version_name)
for epoch in range(n_epochs):
model.train()
for points, targets, frame_ids in train_dataloader:
outputs = tracer(
points.to(device),
targets,
input_ids=frame_ids,
dataset_type="train",
n_epoch=epoch,
)
loss = compute_loss(outputs, targets)
loss.backward()
optimizer.step()
tracer.wait_for_save(n_epoch=n_epochs - 1)
# === Step 2: Evaluation ===
evaluator = Evaluator(model_name, version_name)
result = evaluator.request_evaluation(n_epoch="latest")
# === Step 3: Data Selection ===
data_filter = ObjectDetection3DDataFilter(
model, model_name, version_name,
result_name=result.result_name,
)
model.eval()
for points, _, frame_ids in pool_dataloader:
data_filter(points.to(device), input_ids=frame_ids)
queried_ids = data_filter.query(n_data=n_query, strategy="high_error_proba")
# === Step 4: Update Dataset ===
queried_indices = [int(id.replace("pool_", "")) for id in queried_ids]
train_indices += queried_indices
pool_indices = list(set(pool_indices) - set(queried_indices))
データ選択戦略
ソート戦略(全タスク共通)
| 戦略 | 説明 | 用途 |
|---|---|---|
"high_error_proba" |
エラー確率が高いデータを優先選択 | Hard Example Mining |
"low_error_proba" |
エラー確率が低いデータを優先選択 | 信頼性の高いデータ収集 |
フィルタ戦略
Issue Categoryに基づくゾーン指定や閾値条件でデータを選択します。
# Hotspotゾーンのデータを選択
queried_ids = data_filter.query(
n_data=n_query,
strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)
# カスタム閾値条件
queried_ids = data_filter.query(
n_data=n_query,
strategy={"conditions": [{"error_proba": ">=0.8"}]},
)
BBoxStrategy 辞書を使用して、BBox単位のエラー確率を画像単位に集約する条件を指定します。
# 検出エラー確率が高いBBox上位3つの平均で集約
queried_ids = data_filter.query(
n_data=n_query,
strategy={
"target_column": "error_proba",
"top_n": 3,
"aggregation": "mean",
},
)
ソート戦略("high_error_proba" 等)もClassificationと同様に使用できます。
詳細は Getting Started - DataFiltering を参照してください。
ベストプラクティス
1. 適切な評価頻度
毎エポック評価すると計算コストが高いため、数エポックごとに評価することを推奨します。
if epoch % eval_interval == 0:
result = evaluator.request_evaluation(n_epoch=epoch)
2. データセット管理
選択済みデータの重複を防ぐため、プールからの削除を確実に行います。
pool_indices = list(set(pool_indices) - set(queried_indices))
3. 早期停止
性能が収束したら早期終了を検討します。
if current_metric - previous_metric < threshold:
print("Performance converged, stopping early")
break
4. ログの保存
実験結果を追跡するためにログを保存します。
logs.append({
"iteration": iteration,
"train_size": len(train_indices),
"metrics": result.metrics_summary(),
})
リアルタイムフィルタリングとの使い分け
Active Learningでは、query() メソッドによるバッチ後処理方式が推奨されます。
運用環境でのデータ収集最適化(Data Curation)では、filter_config を使用したリアルタイムフィルタリング方式が適しています。
| 方式 | 用途 | 詳細 |
|---|---|---|
バッチ後処理(query()) |
Active Learning | 本ページ |
| リアルタイムフィルタリング | Data Curation | Data Curation |
詳細は Getting Started - DataFiltering を参照してください。
次のステップ
- Data Curation - データキュレーションの概念
- Getting Started - DataFiltering - DataFilterの基本操作