コンテンツにスキップ

DataFiltering

このガイドでは、MLdebugger SDKを使用してエラーパターンに基づくデータフィルタリングを行う方法を説明します。

概要

DataFilteringフローは、評価結果に基づいてプールデータから効率的にデータを選択するためのワークフローです。 2つの方式をサポートしています:

  • バッチ後処理方式: 全データを推論後、query()メソッドでまとめて選択
  • リアルタイムフィルタリング方式: 推論時にリアルタイムでフィルタ条件を判定

STEP 1で使用するDataFilterクラスはタスクによって異なります:

タスク DataFilterクラス
Classification ClassificationDataFilter
Object Detection ObjectDetectionDataFilter
3D Object Detection ObjectDetection3DDataFilter

前提条件: Tracing + Evaluation が完了し、result_name が取得済みであること。

STEP 1: DataFilterの初期化

バッチ後処理方式(基本)

from ml_debugger.data_filter import ClassificationDataFilter

data_filter = ClassificationDataFilter(
    model,                                      # 評価対象のモデル
    model_name="resnet18",                      # Tracerで使用したmodel_name
    version_name="v1",                          # Tracerで使用したversion_name
    result_name="resnet18_v1_classification_v1_20251219",  # 評価結果のresult_name
)

リアルタイムフィルタリング方式

推論時にリアルタイムでデータをフィルタリングしたい場合は、filter_configパラメータを使用します。

data_filter = ClassificationDataFilter(
    model,
    model_name="resnet18",
    version_name="v1",
    result_name="resnet18_v1_classification_v1_20251219",
    filter_config={"target_zones": ["critical_hotspot"]},  # Critical Hotspotゾーンを対象
)

target_zonesの選択

  • critical_hotspot: モデルが高い確信度で誤った予測をするデータ。品質改善に最も重要。
  • hotspot: 一般的なエラーゾーン。幅広いエラーパターンを収集したい場合に使用。
  • stable_coverage: モデルが高い確信度で正しく予測するデータ。学習寄与度が低いデータの間引きに使用できます。
  • 複数指定可能: {"target_zones": ["critical_hotspot", "hotspot"]}(OR条件)

利用可能なゾーンの詳細や条件は、GUIのHeatmap画面で確認できます。

バッチ後処理方式(基本)

from ml_debugger.data_filter import ObjectDetectionDataFilter

data_filter = ObjectDetectionDataFilter(
    model,                                      # Object Detectionモデル
    model_name="faster_rcnn",                   # Tracerで使用したmodel_name
    version_name="v1",                          # Tracerで使用したversion_name
    result_name="faster_rcnn_v1_od_v1_20251219",  # 評価結果のresult_name
)

リアルタイムフィルタリング方式

Object Detectionでは、filter_configBBoxStrategy 辞書を渡します。 詳細は STEP 3 の Object Detection タブを参照してください。

data_filter = ObjectDetectionDataFilter(
    model,
    model_name="faster_rcnn",
    version_name="v1",
    result_name="faster_rcnn_v1_od_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

バッチ後処理方式(基本)

from ml_debugger.data_filter import ObjectDetection3DDataFilter

data_filter = ObjectDetection3DDataFilter(
    model,                                      # 3D Object Detectionモデル
    model_name="centerpoint",                   # Tracerで使用したmodel_name
    version_name="v1",                          # Tracerで使用したversion_name
    result_name="centerpoint_v1_od3d_v1_20251219",  # 評価結果のresult_name
)

リアルタイムフィルタリング方式

2D Object Detectionと同様に、filter_configBBoxStrategy 辞書を渡します。 詳細は STEP 3 の Object Detection タブを参照してください。

data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="centerpoint",
    version_name="v1",
    result_name="centerpoint_v1_od3d_v1_20251219",
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

result_nameの取得

result_nameEvaluator.request_evaluation()の戻り値であるResultオブジェクトから取得できます。

result = evaluator.request_evaluation()
result_name = result.result_name

STEP 2: プールデータの推論

バッチ後処理方式

ラベルなしのプールデータに対して推論を実行し、内部特徴量を収集します。 filter_configが未設定の場合、filter_flagsは全てNoneになります。

import torch

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)

    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,                             # 入力データ
            input_ids=ids,                          # データの識別子
        )
        # filter_configが未設定の場合、filter_flagsは全てNone

ラベルは不要

DataFilterでは正解ラベルは不要です。プールデータの推論結果と内部特徴量のみを収集します。

3D Object Detectionの入力形式

3D Object Detectionでは、入力データにdict形式(points, img 等)を使用します。 詳細は Tracing + Evaluation の3D Object Detectionタブを参照してください。

リアルタイムフィルタリング方式

filter_configを指定した場合、各データがフィルタ条件にマッチしたかを示すfilter_flagsを使って、リアルタイムで分岐処理ができます。

import torch

new_data_indices = []  # フィルタ条件にマッチしたデータのインデックス

for input_data, _, ids in pool_dataloader:
    input_data = input_data.to(device)

    with torch.no_grad():
        model_output, filter_flags = data_filter(
            input_data,
            input_ids=ids,
        )

    # filter_flagsを使って分岐処理
    for id, flag in zip(ids, filter_flags):
        if flag is True:
            # フィルタ条件にマッチ → データセットに追加する処理へ
            new_data_indices.append(id)

print(f"Total {len(new_data_indices)} data points matched filter condition")

filter_flagsの値

  • True: データがフィルタ条件にマッチ(抽出対象)
  • False: データがフィルタ条件にマッチしない
  • None: filter_configが未設定

STEP 3: データのクエリ(バッチ後処理方式)

収集したデータから、指定した戦略に基づいてデータを選択します。

ソート戦略

データを指定の指標で並び替え、上位N件を取得します。

queried_ids = data_filter.query(
    n_data=50,                      # 取得するデータ数
    strategy="high_error_proba",    # エラー確率が高い順
)
戦略 説明
"high_error_proba" エラー確率が高いデータを優先選択(Hard Example Mining)
"low_error_proba" エラー確率が低いデータを優先選択

フィルタ戦略

条件にマッチするデータを最大N件取得します。辞書形式で条件を指定します。

# ゾーン指定: Hotspotゾーンのデータを最大50件取得
queried_ids = data_filter.query(
    n_data=50,
    strategy={"target_zones": ["hotspot", "critical_hotspot"]},
)

# 閾値指定: エラー確率が0.8以上のデータを最大50件取得
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.8"}]},
)

# 複合条件: 複数の閾値条件(AND)
queried_ids = data_filter.query(
    n_data=50,
    strategy={"conditions": [{"error_proba": ">=0.5", "score": "<0.7"}]},
)

ゾーンと閾値条件の使い分け

  • ゾーン指定: エラーパターンに基づく直感的な選択
  • 閾値指定: 細かな条件でカスタム選択したい場合に使用
  • 両方を指定した場合、OR条件として評価されます

利用可能なゾーン

ゾーン名 Issue Category 説明
stable_coverage Highly Stable 高品質な出力が可能で内部特徴量が安定
operational_coverage Stable 運用上許容可能で内部特徴量が安定
hotspot Unstable 内部特徴量が不安定で予測が揺らぎやすい
recessive_hotspot Under-Confidence モデルの確信度が低くエラーを予測しやすい
critical_hotspot Over-Confidence エラー確率が高いが高確信度で予測
aleatoric_hotspot Outlier 特徴量が不十分でモデルが学習困難

Object Detection(2D / 3D共通)では、strategy にソート戦略または BBoxStrategy 辞書を指定します。 BBox単位のエラー確率を画像(フレーム)単位に集約する条件をカスタマイズできます。

ソート戦略

queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",    # エラー確率が高い順
)
戦略 説明
"high_error_proba" エラー確率が高いデータを優先選択
"low_error_proba" エラー確率が低いデータを優先選択

BBoxStrategy辞書による戦略

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",  # 集約対象のエラー列
        "top_n": 3,                          # 上位N個のBBoxを選択
        "aggregation": "mean",               # 集約方法(mean / median)
    },
)
キー デフォルト 説明
target_column str "error_proba" 集約対象のエラー確率列
top_n Optional[int] None 上位N個のBBoxを選択(Noneで全BBox)
aggregation str "mean" 集約方法("mean" / "median"
weight_column Optional[str] None 重み付き集約に使用する列
bbox_filter Optional[dict] None BBox単位のフィルタ条件
sort str "desc" ソート順("asc" / "desc"

ソート戦略の互換性

"high_error_proba" / "low_error_proba" といった文字列戦略は、Classificationと同様にObject Detectionでも使用できます。

BBoxが検出されない画像について

v0.3.1時点で提供している機能では、NMS設定によりBBoxが1つも出力されない画像はクエリの対象外となります。 そのような画像が大量に存在する場合、n_data で指定した数よりも少ないデータが返されることがあります。

オプションパラメータ

queried_ids = data_filter.query(
    n_data=50,
    strategy="high_error_proba",
    dataset_type="pool",            # フィルタリング対象のdataset_type
    type_cast=int,                  # input_idの型変換
)

完全なサンプルコード

import os
import torch
from ml_debugger.data_filter import ClassificationDataFilter

# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # 学習済みモデル
pool_dataloader = ...  # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- バッチ後処理方式 ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=100,
    strategy="high_error_proba",
    type_cast=int,
)

# --- リアルタイムフィルタリング方式 ---
data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"target_zones": ["critical_hotspot"]},
)

critical_data_indices = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            critical_data_indices.append(idx.item())
import os
import torch
from ml_debugger.data_filter import ObjectDetectionDataFilter

# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # Object Detectionモデル
pool_dataloader = ...  # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- バッチ後処理方式 ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,
)

for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- リアルタイムフィルタリング方式 ---
data_filter = ObjectDetectionDataFilter(
    model,
    model_name="my_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_images = []
for image, _, indices in pool_dataloader:
    image = image.to(device)
    with torch.no_grad():
        model_output, filter_flags = data_filter(image, input_ids=indices.cpu().numpy())
    for idx, flag in zip(indices, filter_flags):
        if flag is True:
            high_error_images.append(idx.item())
import os
import torch
from ml_debugger.data_filter import ObjectDetection3DDataFilter

# 認証情報設定
os.environ["MLD_API_ENDPOINT"] = "https://api.adansons.ai"
os.environ["MLD_API_KEY"] = "mldbg_*************"

model = ...  # 3D Object Detectionモデル
pool_dataloader = ...  # ラベルなしプールデータのDataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# --- バッチ後処理方式 ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,
)

for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    model_output, filter_flags = data_filter(points, input_ids=frame_ids)

queried_ids = data_filter.query(
    n_data=50,
    strategy={
        "target_column": "error_proba",
        "top_n": 3,
        "aggregation": "mean",
    },
)

# --- リアルタイムフィルタリング方式 ---
data_filter = ObjectDetection3DDataFilter(
    model,
    model_name="my_3d_od_model",
    version_name="v1",
    result_name=result.result_name,
    filter_config={"img_error_threshold": 0.5, "aggregation": "mean"},
)

high_error_frames = []
for points, _, frame_ids in pool_dataloader:
    points = points.to(device)
    model_output, filter_flags = data_filter(points, input_ids=frame_ids)
    for fid, flag in zip(frame_ids, filter_flags):
        if flag is True:
            high_error_frames.append(fid)

Active Learningワークフロー

DataFilteringは、Active Learningのループ内で使用することを想定しています。

# Active Learning ループ
for iteration in range(n_iterations):
    # 1. 現在のトレーニングデータで学習
    train_model(model, train_dataloader)

    # 2. 評価実行
    tracer = ClassificationTracer(model, model_name, version_name)
    # ... データ収集 ...
    result = evaluator.request_evaluation()

    # 3. プールデータからクエリ
    data_filter = ClassificationDataFilter(
        model, model_name, version_name, result.result_name
    )
    # ... プールデータ推論 ...
    queried_ids = data_filter.query(n_data=100, strategy="high_error_proba")

    # 4. 選択されたデータをトレーニングデータに追加
    train_indices += queried_ids
    pool_indices = list(set(pool_indices) - set(queried_ids))

詳細な実装例は Active Learning ユースケース を参照してください。

n_epochの指定

同一version_nameで複数回の学習を行った場合、データのコンフリクトを避けるためにn_epochを明示的に指定することが有効です。

DataFilter初期化時のn_epoch指定

data_filter = ClassificationDataFilter(
    model,
    model_name="my_model",
    version_name="v1",
    result_name=result.result_name,
    n_epoch=5,  # 特定のエポックのデータを使用
)

なぜn_epochの指定が重要か

n_epoch="latest"はtimestampに基づいて最新エポックを判定するため、同じversion_nameでエポックを1からやり直した場合、意図しないデータが使用される可能性があります。

# 問題のあるケース
# 過去にepoch 0-9で学習し、その後epoch 0-4で再学習した場合
# latest は epoch=4 を指す(過去のepoch=9ではなく)

# 解決策1: n_epochを明示的に指定
data_filter = ClassificationDataFilter(..., n_epoch=9)

# 解決策2: 新しいversion_nameを使用
data_filter = ClassificationDataFilter(..., version_name="v2")

推奨される対処法

  • 新しい学習を開始する場合は、新しいversion_nameを使用
  • 特定のエポックのデータを使用したい場合は、n_epochを明示的に指定
  • result_nameを指定することで、評価結果に紐づいたデータを確実に使用

詳細はmodel_name / version_nameを参照してください。

次のステップ

  • Logging - 運用時の推論ログ収集