ClassificationDataFilter Class
ml_debugger.data_filter.classification.classification_torch_data_filter.ClassificationTorchDataFilter
Bases: CommonDataFilter, ClassificationTorchTracer
DataFilter for classification tasks using PyTorch models.
__init__(model, model_name, version_name, result_name=None, n_epoch='latest', filter_config=None, target_layers=None, additional_fields=None, auto_sync=False, force_table_recreate=False, api_endpoint=None, api_key=None)
Initialize classification data filter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model to trace. | *required* |
| `model_name` | `str` | Name of the ML model. | *required* |
| `version_name` | `str` | Version identifier for the ML model. | *required* |
| `result_name` | `Optional[str]` | Name of the existing evaluation result to retrieve. | `None` |
| `n_epoch` | `Union[str, Optional[int]]` | Filter option for the n_epoch value. | `'latest'` |
| `filter_config` | `Optional[Union[FilterConfig, Dict[str, Any]]]` | Filter configuration for data filtering. Can be a `FilterConfig` instance or a dict that will be validated and converted to `FilterConfig`. Example: `{"target_zones": ["hotspot"]}` or `{"conditions": [{"error_proba": ">=0.8"}]}`. | `None` |
| `target_layers` | `Optional[Dict[str, str]]` | Mapping of layer aliases to module paths. | `None` |
| `additional_fields` | `Optional[List[dict]]` | Extra fields for the database schema. | `None` |
| `auto_sync` | `bool` | Enable background syncing of logged data. | `False` |
| `force_table_recreate` | `bool` | Whether to drop and recreate existing tables. | `False` |
| `api_endpoint` | `Optional[str]` | URL of the service API for data upload. | `None` |
| `api_key` | `Optional[str]` | API key for authenticating with the service. | `None` |
__call__(model_input, input_ids, dataset_type='pool', **kwargs)
Invoke the tracer on a single inference, recording I/O data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_input` | `Any` | Input data for the model inference. | *required* |
| `input_ids` | `List[str]` | Identifiers of each input data item. | *required* |
| `dataset_type` | `str` | Identifier of the input dataset (e.g. `'pool'`). | `'pool'` |
| `**kwargs` | `Any` | Additional keyword arguments for parsing and saving I/O data (will be passed to …). | `{}` |

Returns:

| Type | Description |
|---|---|
| `Tuple[Any, List[Optional[bool]]]` | `model_output`: raw model output. `filter_flags`: list of booleans indicating whether each input matches the filter condition (`True` = matches the filter and should be extracted, `False` = does not match, `None` = no filter configured). |

Raises:

| Type | Description |
|---|---|
| `Any` | Propagates exceptions from parsing and saving operations. |
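Each entry of `filter_flags` is paired with the input id at the same position. A minimal sketch (not library code) of how a caller might collect the ids flagged for extraction, keeping only `True` entries and skipping both `False` and the `None` no-filter case:

```python
def extracted_ids(input_ids, filter_flags):
    """Keep only the ids whose flag is True (matched the filter).

    False means the record did not match; None means no filter was
    configured, so the record is not treated as extracted.
    """
    return [i for i, flag in zip(input_ids, filter_flags) if flag is True]

# Simulated output of a __call__ over three inputs: one match, one
# non-match, and one with no filter configured.
ids = ["img_001", "img_002", "img_003"]
flags = [True, False, None]
print(extracted_ids(ids, flags))  # ['img_001']
```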
get_hooked_features(layer_name)
Retrieve the captured output for a given layer alias.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `layer_name` | `str` | Alias of the layer whose activation was captured. | *required* |

Returns:

| Name | Type | Description |
|---|---|---|
| `Any` | `Any` | Activation data stored for the specified layer. |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If no activation has been captured for the given layer alias. |
export(output_path=None)
Export extracted features into a ZIP archive.
Uses the internal n_epoch resolved during validator setup to
filter records, consistent with upload() and wait_for_save().
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `output_path` | `Optional[str]` | Path or directory for saving the ZIP file. If it has no `.zip` extension, the default filename is appended. Defaults to the current working directory. | `None` |

Returns:

| Type | Description |
|---|---|
| `Optional[Path]` | Path to the created ZIP file, or `None` on non-primary distributed ranks. |
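The `output_path` rule above (no `.zip` extension means the path is treated as a directory and a default filename is appended; `None` falls back to the current working directory) can be sketched as follows. `resolve_zip_path` and the default filename are illustrative assumptions, not the library's code:

```python
from pathlib import Path

DEFAULT_NAME = "extracted_features.zip"  # assumed default filename

def resolve_zip_path(output_path=None):
    """Toy mirror of the documented rule: a path without a .zip suffix
    is treated as a directory and the default filename is appended;
    None falls back to the current working directory."""
    base = Path(output_path) if output_path is not None else Path.cwd()
    if base.suffix == ".zip":
        return base
    return base / DEFAULT_NAME

print(resolve_zip_path("results/run1.zip"))
print(resolve_zip_path("results"))
```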
wait_for_save(interval=3)
upload()
query(n_data, strategy, dataset_type='pool', type_cast=None)
Sort and query dataset based on strategy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_data` | `int` | Maximum number of data items to query. | *required* |
| `strategy` | `Union[str, FilterConfig, Dict[str, Any]]` | Query strategy. `'high_error_proba'`: sort by error probability, descending. `'low_error_proba'`: sort by error probability, ascending. `FilterConfig` or dict: filter by conditions and return matching data. | *required* |
| `dataset_type` | `str` | Filter for the input dataset (e.g. `'pool'`). | `'pool'` |
| `type_cast` | `Optional[type]` | Type for casting `input_id` (e.g. `int`). | `None` |

Returns:

| Type | Description |
|---|---|
| `List[str]` | List of `input_id` values of the queried dataset. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If strategy is invalid. |
Examples:
>>> # Sort by high error probability
>>> ids = data_filter.query(n_data=10, strategy='high_error_proba')
>>>
>>> # Filter by zone
>>> ids = data_filter.query(n_data=10, strategy={'target_zones': ['hotspot']})
>>>
>>> # Filter by conditions
>>> ids = data_filter.query(n_data=10, strategy={'conditions': [{'error_proba': '>=0.8'}]})
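For intuition, the `'high_error_proba'` strategy amounts to sorting records by error probability in descending order and truncating to `n_data`. A toy illustration over plain dicts (not the library's query engine), including what `type_cast=int` would do to the returned ids:

```python
# Toy records standing in for logged inference results.
records = [
    {"input_id": "7", "error_proba": 0.35},
    {"input_id": "2", "error_proba": 0.91},
    {"input_id": "5", "error_proba": 0.64},
]

def toy_query(records, n_data, descending=True):
    """Sort by error_proba and return at most n_data input_ids."""
    ranked = sorted(records, key=lambda r: r["error_proba"], reverse=descending)
    return [r["input_id"] for r in ranked[:n_data]]

print(toy_query(records, n_data=2))                    # ['2', '5']
# type_cast=int in query() would map these ids through int():
print([int(i) for i in toy_query(records, n_data=2)])  # [2, 5]
```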