ClassificationDataFilter Class

ml_debugger.data_filter.classification.classification_torch_data_filter.ClassificationTorchDataFilter

Bases: CommonDataFilter, ClassificationTorchTracer

DataFilter for classification tasks using PyTorch models.

__init__(model, model_name, version_name, result_name=None, n_epoch='latest', filter_config=None, target_layers=None, additional_fields=None, auto_sync=False, force_table_recreate=False, api_endpoint=None, api_key=None)

Initialize classification data filter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | PyTorch model to trace. | required |
| model_name | str | Name of the ML model. | required |
| version_name | str | Version identifier for the ML model. | required |
| result_name | Optional[str] | The name of the existing evaluation result to retrieve. | None |
| n_epoch | Union[str, Optional[int]] | Filter option for n_epoch value. | 'latest' |
| filter_config | Optional[Union[FilterConfig, Dict[str, Any]]] | Filter configuration for data filtering. Can be a FilterConfig instance or a dict that will be validated and converted to FilterConfig. Example: {"target_zones": ["hotspot"]} or {"conditions": [{"error_proba": ">=0.8"}]} | None |
| target_layers | Optional[Dict[str, str]] | Mapping of layer aliases to module paths. | None |
| additional_fields | Optional[List[dict]] | Extra fields for database schema. | None |
| auto_sync | bool | Enable background syncing of logged data. | False |
| force_table_recreate | bool | Whether to drop and recreate existing tables. | False |
| api_endpoint | Optional[str] | URL of the service API for data upload. | None |
| api_key | Optional[str] | API key for authenticating with the service. | None |
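
The condition syntax accepted by filter_config can be illustrated with a small stand-alone sketch. The helper `matches_condition` below is hypothetical and not part of ml_debugger. It only assumes that a condition string such as ">=0.8" is a comparison operator followed by a numeric threshold; the library's own validation may behave differently.

```python
import operator
import re

# Hypothetical helper, not part of ml_debugger: interprets a condition
# string such as ">=0.8" as "<comparison operator><numeric threshold>".
_OPERATORS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def matches_condition(value: float, condition: str) -> bool:
    """Return True if `value` satisfies a condition string like '>=0.8'."""
    match = re.fullmatch(r"\s*(>=|<=|==|>|<)\s*(-?\d+(?:\.\d+)?)\s*", condition)
    if match is None:
        raise ValueError(f"Unsupported condition: {condition!r}")
    op, threshold = match.groups()
    return _OPERATORS[op](value, float(threshold))


# The two example configurations quoted in the filter_config description.
zone_config = {"target_zones": ["hotspot"]}
condition_config = {"conditions": [{"error_proba": ">=0.8"}]}

# Apply the assumed semantics of the first condition to two sample values.
field, condition = next(iter(condition_config["conditions"][0].items()))
print(field, matches_condition(0.9, condition), matches_condition(0.5, condition))
```

Either dict would be passed as filter_config (or as the strategy argument of query()); the sketch only shows one plausible reading of what a single condition means.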

__call__(model_input, input_ids, dataset_type='pool', **kwargs)

Invoke the tracer on a single inference, recording I/O data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_input | Any | Input data for the model inference. | required |
| input_ids | List[str] | Identifiers of each input data. | required |
| dataset_type | str | Identifier of input dataset. (e.g. 'pool') | 'pool' |
| **kwargs | Any | Additional keyword arguments for parsing and saving I/O data. (will be passed to self._parse_and_save_io_data) | {} |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Any, List[Optional[bool]]] | A tuple (model_output, filter_flags). model_output: raw model output. filter_flags: list of booleans indicating whether each input matches the filter condition (True = matches filter / should be extracted, False = does not match, None = no filter configured). |

Raises:

| Type | Description |
| --- | --- |
| Any | Propagates exceptions from parsing and saving operations. |
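
A minimal sketch of how the returned filter_flags might be consumed. The values below are stand-ins shaped like the documented return value, not output from a real model, and `select_flagged` is a hypothetical helper. Treating None (no filter configured) the same as False is an assumption made for this example.

```python
from typing import List, Optional


def select_flagged(input_ids: List[str], filter_flags: List[Optional[bool]]) -> List[str]:
    """Keep the IDs whose flag is True; False and None are both skipped."""
    return [input_id for input_id, flag in zip(input_ids, filter_flags) if flag is True]


# Stand-in values shaped like the documented return value of __call__.
input_ids = ["img_001", "img_002", "img_003"]
filter_flags = [True, False, True]

selected = select_flagged(input_ids, filter_flags)
print(selected)  # ['img_001', 'img_003']

# With no filter configured every flag is None, so nothing is selected.
print(select_flagged(input_ids, [None, None, None]))  # []
```

Comparing with `is True` rather than relying on truthiness keeps the three documented states (True, False, None) explicit.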

get_hooked_features(layer_name)

Retrieve the captured output for a given layer alias.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| layer_name | str | Alias of the layer whose activation was captured. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Activation data stored for the specified layer. |

Raises:

| Type | Description |
| --- | --- |
| KeyError | If no activation has been captured for layer_name. |

export(output_path=None)

Export extracted features into a ZIP archive.

Uses the internal n_epoch resolved during validator setup to filter records, consistent with upload() and wait_for_save().

Parameters:

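| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output_path | Optional[str] | Path or directory for saving the ZIP file. If the path has no .zip extension, it is treated as a directory and the default filename is appended. Defaults to the current working directory. | None |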

Returns:

| Type | Description |
| --- | --- |
| Optional[Path] | Path to the created ZIP file, or None on non-primary distributed ranks. |
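
The output_path rule described above can be sketched with pathlib. This is an illustration of the documented behavior, not the library's implementation: `resolve_zip_path` is a hypothetical helper, and `DEFAULT_FILENAME` is a placeholder because the real default filename is not stated in this reference.

```python
from pathlib import Path
from typing import Optional

# Hypothetical default; the real default filename is not stated in this reference.
DEFAULT_FILENAME = "export.zip"


def resolve_zip_path(output_path: Optional[str] = None) -> Path:
    """Sketch of the documented rule for turning output_path into a ZIP path."""
    if output_path is None:
        # No path given: fall back to the current working directory.
        return Path.cwd() / DEFAULT_FILENAME
    path = Path(output_path)
    if path.suffix.lower() == ".zip":
        # Already names a ZIP file: use it unchanged.
        return path
    # Otherwise treat it as a directory and append the default filename.
    return path / DEFAULT_FILENAME


print(resolve_zip_path("results/run1.zip"))
print(resolve_zip_path("results"))
```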

wait_for_save(interval=3)

upload()

query(n_data, strategy, dataset_type='pool', type_cast=None)

Sort and query dataset based on strategy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_data | int | Maximum number of data to query. | required |
| strategy | Union[str, FilterConfig, Dict[str, Any]] | Query strategy. 'high_error_proba': sort by error probability descending. 'low_error_proba': sort by error probability ascending. FilterConfig or dict: filter by conditions and return matching data. | required |
| dataset_type | str | Filter of input dataset. (e.g. 'pool') | 'pool' |
| type_cast | Optional[type] | Type for casting input_id. (e.g. int) | None |

Returns:

| Type | Description |
| --- | --- |
| List[str] | List of input_id values of the queried data. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If strategy is invalid. |
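
The two string strategies and type_cast can be sketched on in-memory data. `query_sketch` is a hypothetical stand-in, not the library's implementation, and the record layout (dicts with `input_id` and `error_proba` keys) is an assumption made for illustration; the real method reads from stored results.

```python
from typing import Any, Callable, Dict, List, Optional


def query_sketch(
    records: List[Dict[str, Any]],
    n_data: int,
    strategy: str,
    type_cast: Optional[Callable[[str], Any]] = None,
) -> List[Any]:
    """Illustrative stand-in for the two string strategies of query()."""
    if strategy == "high_error_proba":
        descending = True
    elif strategy == "low_error_proba":
        descending = False
    else:
        raise ValueError(f"Invalid strategy: {strategy!r}")
    # Rank by error probability, then keep at most n_data IDs.
    ranked = sorted(records, key=lambda r: r["error_proba"], reverse=descending)
    ids = [r["input_id"] for r in ranked[:n_data]]
    # Optionally cast each ID, e.g. from str to int.
    return [type_cast(i) for i in ids] if type_cast is not None else ids


records = [
    {"input_id": "1", "error_proba": 0.10},
    {"input_id": "2", "error_proba": 0.95},
    {"input_id": "3", "error_proba": 0.60},
]
print(query_sketch(records, n_data=2, strategy="high_error_proba"))  # ['2', '3']
print(query_sketch(records, n_data=2, strategy="low_error_proba", type_cast=int))  # [1, 3]
```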

Examples:

>>> # Sort by high error probability
>>> ids = data_filter.query(n_data=10, strategy='high_error_proba')
>>>
>>> # Filter by zone
>>> ids = data_filter.query(n_data=10, strategy={'target_zones': ['hotspot']})
>>>
>>> # Filter by conditions
>>> ids = data_filter.query(n_data=10, strategy={'conditions': [{'error_proba': '>=0.8'}]})