DataFiltering
このガイドでは、MLdebugger SDKを使用してエラーパターンに基づくデータフィルタリングを行う方法を説明します。
概要
DataFilteringフローは、評価結果に基づいてプールデータから効率的にデータを選択するためのワークフローです。 2つの方式をサポートしています:
- バッチ後処理方式: 全データを推論後、
query()メソッドでまとめて選択 - リアルタイムフィルタリング方式: 推論時にリアルタイムでフィルタ条件を判定
STEP 1で使用するDataFilterクラスはタスクによって異なります:
| タスク | DataFilterクラス |
|---|---|
| Classification | ClassificationDataFilter |
| Object Detection | ObjectDetectionDataFilter |
| 3D Object Detection | ObjectDetection3DDataFilter |
前提条件: Tracing + Evaluation が完了し、result_name が取得済みであること。
STEP 1: DataFilterの初期化
バッチ後処理方式(基本)
from ml_debugger.data_filter import ClassificationDataFilter
data_filter = ClassificationDataFilter(
model, # 評価対象のモデル
model_name="resnet18", # Tracerで使用したmodel_name
version_name="v1", # Tracerで使用したversion_name
result_name="resnet18_v1_classification_v1_20251219", # 評価結果のresult_name
)
リアルタイムフィルタリング方式
推論時にリアルタイムでデータをフィルタリングしたい場合は、filter_configパラメータを使用します。
data_filter = ClassificationDataFilter(
model,
model_name="resnet18",
version_name="v1",
result_name="resnet18_v1_classification_v1_20251219",
filter_config={"target_zones": ["critical_hotspot"]}, # Critical Hotspotゾーンを対象
)
target_zonesの選択
critical_hotspot: モデルが高い確信度で誤った予測をするデータ。品質改善に最も重要。hotspot: 一般的なエラーゾーン。幅広いエラーパターンを収集したい場合に使用。stable_coverage: モデルが高い確信度で正しく予測するデータ。学習寄与度が低いデータの間引きに使用できます。- 複数指定可能:
{"target_zones": ["critical_hotspot", "hotspot"]}(OR条件)
利用可能なゾーンの詳細や条件は、GUIのHeatmap画面で確認できます。
バッチ後処理方式(基本)
from ml_debugger.data_filter import ObjectDetectionDataFilter
data_filter = ObjectDetectionDataFilter(
model, # Object Detectionモデル
model_name="faster_rcnn", # Tracerで使用したmodel_name
version_name="v1", # Tracerで使用したversion_name
result_name="faster_rcnn_v1_od_v1_20251219", # 評価結果のresult_name
)
リアルタイムフィルタリング方式
Object Detectionでは、filter_config に BBoxStrategy 辞書を渡します。
詳細は STEP 3 の Object Detection タブを参照してください。
data_filter = ObjectDetectionDataFilter(
model,
model_name="faster_rcnn",
version_name="v1",
result_name="faster_rcnn_v1_od_v1_20251219",
filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
バッチ後処理方式(基本)
from ml_debugger.data_filter import ObjectDetection3DDataFilter
data_filter = ObjectDetection3DDataFilter(
model, # 3D Object Detectionモデル
model_name="centerpoint", # Tracerで使用したmodel_name
version_name="v1", # Tracerで使用したversion_name
result_name="centerpoint_v1_od3d_v1_20251219", # 評価結果のresult_name
)
リアルタイムフィルタリング方式
2D Object Detectionと同様に、filter_config に BBoxStrategy 辞書を渡します。
詳細は STEP 3 の Object Detection タブを参照してください。
data_filter = ObjectDetection3DDataFilter(
model,
model_name="centerpoint",
version_name="v1",
result_name="centerpoint_v1_od3d_v1_20251219",
filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
result_nameの取得
result_nameはEvaluator.request_evaluation()の戻り値であるResultオブジェクトから取得できます。
result = evaluator.request_evaluation()
result_name = result.result_name
STEP 2: プールデータの推論
バッチ後処理方式
ラベルなしのプールデータに対して推論を実行し、内部特徴量を収集します。
filter_configが未設定の場合、filter_flagsは全てNoneになります。
import torch
for input_data, _, ids in pool_dataloader:
input_data = input_data.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(
input_data, # 入力データ
input_ids=ids, # データの識別子
)
# filter_configが未設定の場合、filter_flagsは全てNone
ラベルは不要
DataFilterでは正解ラベルは不要です。プールデータの推論結果と内部特徴量のみを収集します。
3D Object Detectionの入力形式
3D Object Detectionでは、入力データにdict形式(points, img 等)を使用します。
詳細は Tracing + Evaluation の3D Object Detectionタブを参照してください。
リアルタイムフィルタリング方式
filter_configを指定した場合、各データがフィルタ条件にマッチしたかを示すfilter_flagsを使って、リアルタイムで分岐処理ができます。
import torch
new_data_indices = [] # フィルタ条件にマッチしたデータのインデックス
for input_data, _, ids in pool_dataloader:
input_data = input_data.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(
input_data,
input_ids=ids,
)
# filter_flagsを使って分岐処理
for id, flag in zip(ids, filter_flags):
if flag is True:
# フィルタ条件にマッチ → データセットに追加する処理へ
new_data_indices.append(id)
print(f"Total {len(new_data_indices)} data points matched filter condition")
filter_flagsの値
True: データがフィルタ条件にマッチ(抽出対象)False: データがフィルタ条件にマッチしないNone:filter_configが未設定
STEP 3: データのクエリ(バッチ後処理方式)
収集したデータから、指定した戦略に基づいてデータを選択します。
ソート戦略
データを指定の指標で並び替え、上位N件を取得します。
queried_ids = data_filter.query(
n_data=50, # 取得するデータ数
strategy="high_error_proba", # エラー確率が高い順
)
| 戦略 | 説明 |
|---|---|
"high_error_proba" |
エラー確率が高いデータを優先選択(Hard Example Mining) |
"low_error_proba" |
エラー確率が低いデータを優先選択 |
フィルタ戦略
条件にマッチするデータを最大N件取得します。辞書形式で条件を指定します。
# ゾーン指定: Hotspotゾーンのデータを最大50件取得
queried_ids = data_filter.query(
n_data=50,
strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)
# 閾値指定: エラー確率が0.8以上のデータを最大50件取得
queried_ids = data_filter.query(
n_data=50,
strategy={"conditions": [{"error_proba": ">=0.8"}]},
)
# 複合条件: 複数の閾値条件(AND)
queried_ids = data_filter.query(
n_data=50,
strategy={"conditions": [{"error_proba": ">=0.5", "score": "<0.7"}]},
)
ゾーンと閾値条件の使い分け
- ゾーン指定: エラーパターンに基づく直感的な選択
- 閾値指定: 細かな条件でカスタム選択したい場合に使用
- 両方を指定した場合、OR条件として評価されます
利用可能なゾーン
| ゾーン名 | Issue Category | 説明 |
|---|---|---|
stable_coverage |
Highly Stable | 高品質な出力が可能で内部特徴量が安定 |
operational_coverage |
Stable | 運用上許容可能で内部特徴量が安定 |
hotspot |
Unstable | 内部特徴量が不安定で予測が揺らぎやすい |
recessive_hotspot |
Under-Confidence | モデルの確信度が低くエラーを予測しやすい |
critical_hotspot |
Over-Confidence | エラー確率が高いが高確信度で予測 |
aleatoric_hotspot |
Outlier | 特徴量が不十分でモデルが学習困難 |
Object Detection(2D / 3D共通)では、strategy にソート戦略または BBoxStrategy 辞書を指定します。
BBox単位のエラー確率を画像(フレーム)単位に集約する条件をカスタマイズできます。
ソート戦略
queried_ids = data_filter.query(
n_data=50,
strategy="high_error_proba", # エラー確率が高い順
)
| 戦略 | 説明 |
|---|---|
"high_error_proba" |
エラー確率が高いデータを優先選択 |
"low_error_proba" |
エラー確率が低いデータを優先選択 |
BBoxStrategy辞書による戦略
queried_ids = data_filter.query(
n_data=50,
strategy={
"target_column": "error_proba", # 集約対象のエラー列
"top_n": 3, # 上位N個のBBoxを選択
"aggregation": "mean", # 集約方法(mean / median)
},
)
| キー | 型 | デフォルト | 説明 |
|---|---|---|---|
target_column |
str |
"error_proba" |
集約対象のエラー確率列 |
top_n |
Optional[int] |
None |
上位N個のBBoxを選択(Noneで全BBox) |
aggregation |
str |
"mean" |
集約方法("mean" / "median") |
weight_column |
Optional[str] |
None |
重み付き集約に使用する列 |
bbox_filter |
Optional[dict] |
None |
BBox単位のフィルタ条件 |
sort |
str |
"desc" |
ソート順("asc" / "desc") |
ソート戦略の互換性
"high_error_proba" / "low_error_proba" といった文字列戦略は、Classificationと同様にObject Detectionでも使用できます。
BBoxが検出されない画像について
v0.3.1時点で提供している機能では、NMS設定によりBBoxが1つも出力されない画像はクエリの対象外となります。
そのような画像が大量に存在する場合、n_data で指定した数よりも少ないデータが返されることがあります。
オプションパラメータ
queried_ids = data_filter.query(
n_data=50,
strategy="high_error_proba",
dataset_type="pool", # フィルタリング対象のdataset_type
type_cast=int, # input_idの型変換
)
完全なサンプルコード
import os
import torch
from ml_debugger.data_filter import ClassificationDataFilter
# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"
model = ... # 学習済みモデル
pool_dataloader = ... # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# --- バッチ後処理方式 ---
data_filter = ClassificationDataFilter(
model,
model_name="my_model",
version_name="v1",
result_name=result.result_name,
)
for image, _, indices in pool_dataloader:
image = image.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
queried_ids = data_filter.query(
n_data=100,
strategy="high_error_proba",
type_cast=int,
)
# --- リアルタイムフィルタリング方式 ---
data_filter = ClassificationDataFilter(
model,
model_name="my_model",
version_name="v1",
result_name=result.result_name,
filter_config={"target_zones": ["critical_hotspot"]},
)
critical_data_indices = []
for image, _, indices in pool_dataloader:
image = image.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
for idx, flag in zip(indices, filter_flags):
if flag is True:
critical_data_indices.append(idx.item())
import os
import torch
from ml_debugger.data_filter import ObjectDetectionDataFilter
# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"
model = ... # Object Detectionモデル
pool_dataloader = ... # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# --- バッチ後処理方式 ---
data_filter = ObjectDetectionDataFilter(
model,
model_name="my_od_model",
version_name="v1",
result_name=result.result_name,
)
for image, _, indices in pool_dataloader:
image = image.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
queried_ids = data_filter.query(
n_data=50,
strategy={
"target_column": "error_proba",
"top_n": 3,
"aggregation": "mean",
},
)
# --- リアルタイムフィルタリング方式 ---
data_filter = ObjectDetectionDataFilter(
model,
model_name="my_od_model",
version_name="v1",
result_name=result.result_name,
filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
high_error_images = []
for image, _, indices in pool_dataloader:
image = image.to(device)
with torch.no_grad():
model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
for idx, flag in zip(indices, filter_flags):
if flag is True:
high_error_images.append(idx.item())
import os
import torch
from ml_debugger.data_filter import ObjectDetection3DDataFilter
# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"
model = ... # 3D Object Detectionモデル
pool_dataloader = ... # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# --- バッチ後処理方式 ---
data_filter = ObjectDetection3DDataFilter(
model,
model_name="my_3d_od_model",
version_name="v1",
result_name=result.result_name,
)
for points, _, frame_ids in pool_dataloader:
points = points.to(device)
model_output, filter_flags = data_filter(points, input_ids=frame_ids)
queried_ids = data_filter.query(
n_data=50,
strategy={
"target_column": "error_proba",
"top_n": 3,
"aggregation": "mean",
},
)
# --- リアルタイムフィルタリング方式 ---
data_filter = ObjectDetection3DDataFilter(
model,
model_name="my_3d_od_model",
version_name="v1",
result_name=result.result_name,
filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)
high_error_frames = []
for points, _, frame_ids in pool_dataloader:
points = points.to(device)
model_output, filter_flags = data_filter(points, input_ids=frame_ids)
for fid, flag in zip(frame_ids, filter_flags):
if flag is True:
high_error_frames.append(fid)
Active Learningワークフロー
DataFilteringは、Active Learningのループ内で使用することを想定しています。
# Active Learning ループ
for iteration in range(n_iterations):
# 1. 現在のトレーニングデータで学習
train_model(model, train_dataloader)
# 2. 評価実行
tracer = ClassificationTracer(model, model_name, version_name)
# ... データ収集 ...
result = evaluator.request_evaluation()
# 3. プールデータからクエリ
data_filter = ClassificationDataFilter(
model, model_name, version_name, result.result_name
)
# ... プールデータ推論 ...
queried_ids = data_filter.query(n_data=100, strategy="high_error_proba")
# 4. 選択されたデータをトレーニングデータに追加
train_indices += queried_ids
pool_indices = list(set(pool_indices) - set(queried_ids))
詳細な実装例は Active Learning ユースケース を参照してください。
n_epochの指定
同一version_nameで複数回の学習を行った場合、データのコンフリクトを避けるためにn_epochを明示的に指定することが有効です。
DataFilter初期化時のn_epoch指定
data_filter = ClassificationDataFilter(
model,
model_name="my_model",
version_name="v1",
result_name=result.result_name,
n_epoch=5, # 特定のエポックのデータを使用
)
なぜn_epochの指定が重要か
n_epoch="latest"はtimestampに基づいて最新エポックを判定するため、同じversion_nameでエポックを1からやり直した場合、意図しないデータが使用される可能性があります。
# 問題のあるケース
# 過去にepoch 0-9で学習し、その後epoch 0-4で再学習した場合
# latest は epoch=4 を指す(過去のepoch=9ではなく)
# 解決策1: n_epochを明示的に指定
data_filter = ClassificationDataFilter(..., n_epoch=9)
# 解決策2: 新しいversion_nameを使用
data_filter = ClassificationDataFilter(..., version_name="v2")
推奨される対処法
- 新しい学習を開始する場合は、新しい
version_nameを使用 - 特定のエポックのデータを使用したい場合は、
n_epochを明示的に指定 result_nameを指定することで、評価結果に紐づいたデータを確実に使用
詳細はmodel_name / version_nameを参照してください。
次のステップ
- Logging - 運用時の推論ログ収集