nni.compression.utils.dependency 源代码
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List
import uuid
import torch
from .shape_dependency import ChannelDependency, GroupDependency
from ..base.config import select_modules_by_config, trans_legacy_config_list
[文档]
def auto_set_denpendency_group_ids(model: torch.nn.Module, config_list: List[Dict[str, Any]],
                                   dummy_input: torch.Tensor | List[torch.Tensor] | Dict[str, torch.Tensor]) -> List[Dict[str, Any]]:
    """
    Expand a compression config list into one config per selected module and
    annotate each config with the dependency information found for that module.

    The output dependency is searched between all 'Conv2d', 'Linear',
    'ConvTranspose2d', 'Embedding' modules. Every module that belongs to a
    channel dependency set gets the ``dependency_group_id`` shared by that set;
    an id already present in the original config is replaced by the new one.
    Every module listed in the group dependency additionally gets an
    ``internal_metric_block`` entry.

    Parameters
    ----------
    model
        The origin model.
    config_list
        The compression config list. Legacy-format configs are translated
        before being processed.
    dummy_input
        The dummy input to the model forward function for tracing the model.

    Returns
    -------
    List[Dict[str, Any]]
        A new config list holding one config per selected module. Each config
        carries ``'op_names': [name]`` followed by a deep copy of the settings
        that ``select_modules_by_config`` returned for the originating config.
    """
    channel_dependency = ChannelDependency(model, dummy_input)
    channel_dependency.build_dependency()

    # Every module of one dependency set shares a single, freshly generated id.
    group_id_by_module = {}
    for dependency_set in channel_dependency.dependency_sets:
        shared_id = uuid.uuid4().hex
        for module_name in dependency_set:
            group_id_by_module[module_name] = shared_id

    group_dependency = GroupDependency(model, dummy_input)
    group_dependency.build_dependency()
    group_info = group_dependency.dependency

    expanded_config_list = []
    for config in trans_legacy_config_list(config_list):
        modules, public_config, _ = select_modules_by_config(model, config)
        for name in modules.keys():
            # Deep copy so the per-module configs never share mutable values.
            module_settings = deepcopy(public_config)
            if name in group_id_by_module:
                module_settings['dependency_group_id'] = group_id_by_module[name]
            if name in group_info:
                module_settings['internal_metric_block'] = int(group_info[name])

            # 'op_names' is inserted first; the copied settings are merged on
            # top, so a same-named key among them would take precedence.
            module_config = {'op_names': [name]}
            module_config.update(module_settings)
            expanded_config_list.append(module_config)

    return expanded_config_list