# Source code for nni.algorithms.compression.v2.pytorch.base.scheduler

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union

import json_tricks
import torch
from torch import Tensor
from torch.nn import Module

_logger = logging.getLogger(__name__)

class Task:
    """
    One pruning task: the paths of the model, masks and config list used by this task,
    per-task control flags, and bookkeeping fields (status, score, state).

    Every path returned by :meth:`referenced_paths` is reference-counted across all ``Task``
    instances. When :meth:`clean_up` drops a path's count to 0, the file at that path is deleted.
    Counting is keyed on the path objects exactly as given, so the same file passed once as
    ``str`` and once as ``Path`` is counted under two separate keys.
    """
    # NOTE: If we want to support multi-thread, this part need to refactor, maybe use file and lock to sync.
    # Class-level (shared by every Task): referenced path -> number of tasks that still hold it.
    _reference_counter: Dict[Union[str, Path], int] = {}

    def __init__(self, task_id: int, model_path: Union[str, Path], masks_path: Union[str, Path], config_list_path: Union[str, Path],
                 speedup: Optional[bool] = True, finetune: Optional[bool] = True, evaluate: Optional[bool] = True):
        """
        Parameters
        ----------
        task_id
            The unique id of task.
        model_path
            The path of the unwrapped pytorch model that will be pruned in this task.
        masks_path
            The path of the masks that applied on the model before pruning.
        config_list_path
            The path of the config list that used in this task.
        speedup
            Control if this task needs speedup, True means use scheduler default value, False means no speedup.
        finetune
            Control if this task needs finetune, True means use scheduler default value, False means no finetune.
        evaluate
            Control if this task needs evaluate, True means use scheduler default value, False means no evaluate.
        """
        self.task_id = task_id
        self.model_path = model_path
        self.masks_path = masks_path
        self.config_list_path = config_list_path

        self.speedup = speedup
        self.finetune = finetune
        self.evaluate = evaluate

        self.status = 'Pending'
        self.score: Optional[float] = None

        self.state = {}

        # Register this task as one more holder of each referenced file.
        for ref in self.referenced_paths():
            self._reference_counter.setdefault(ref, 0)
            self._reference_counter[ref] += 1

        self._cleaned = False

    def to_dict(self) -> Dict:
        """
        Returns
        -------
        Dict
            A plain dict describing this task. Paths are converted to ``str``, and ``state`` is
            returned as-is (not copied).
        """
        return {
            'task_id': self.task_id,
            'model_path': str(self.model_path),
            'masks_path': str(self.masks_path),
            'config_list_path': str(self.config_list_path),
            'speedup': self.speedup,
            'finetune': self.finetune,
            'evaluate': self.evaluate,
            'status': self.status,
            'score': self.score,
            'state': self.state
        }

    def load_data(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]:
        """
        Returns
        -------
        Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]
            Return the model pruning in this task, the masks of the model before pruning,
            the config list used in this task.
        """
        model = torch.load(self.model_path)
        masks = torch.load(self.masks_path)
        with Path(self.config_list_path).open('r') as f:
            config_list = json_tricks.load(f)
        return model, masks, config_list

    def referenced_paths(self) -> List[Union[str, Path]]:
        """
        Return the paths whose references are counted for this task.
        """
        return [self.model_path, self.masks_path, self.config_list_path]

    def clean_up(self):
        """
        Release this task's reference on each referenced path. When a path's counter
        reaches 0, the file is deleted. Calling this again on an already cleaned task only
        logs a warning and changes nothing.
        """
        if self._cleaned:
            _logger.warning('Already clean up task %d', self.task_id)
            return
        for ref in self.referenced_paths():
            self._reference_counter[ref] -= 1
            remaining = self._reference_counter[ref]
            if remaining == 0:
                # This was the last task holding the file, so it can be removed.
                os.remove(ref)
            elif remaining < 0:
                # More releases than registrations. The file was already removed when the
                # counter hit 0, so only report the inconsistency.
                _logger.warning('Referance counter error, the number of %s is %d',
                                ref, remaining)
        self._cleaned = True

class TaskResult:
    """
    The result produced by pruning one task: the compact model, its masks, the masks
    generated by the pruner, and the score of the pruning effect.
    """

    def __init__(self, task_id: Union[int, str], compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                 pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
        """
        Parameters
        ----------
        task_id
            The unique id of task.
        compact_model
            The unwrapped compact pytorch model after pruning. If the compact model has been speeduped during the pruning process,
            it will have a smaller structure compared with the model before pruning.
            If the compact model has not been speeduped, it will have the same structure as the model before pruning.
        compact_model_masks
            The masks on the compact model. If the compact model has been speeduped during the pruning process,
            the `compact_model_masks` is always an empty dict. If the compact model has not been speeduped,
            the `compact_model_masks` is same as `pruner_generated_masks`.
        pruner_generated_masks
            The masks that can apply on the before pruning model. It is always the output of `pruner.compress()`.
            TODO: If the compact model has been speeduped, the auto infer masks maybe also need.
        score
            The score of the pruning effect. i.e., the accuracy or latency after pruning.
        """
        self.task_id = task_id
        self.compact_model = compact_model
        self.compact_model_masks = compact_model_masks
        self.pruner_generated_masks = pruner_generated_masks
        self.score = score

class BasePruningScheduler:
    """
    Base class of pruning schedulers.

    :meth:`compress` repeatedly asks :meth:`generate_task` for a task, runs it with
    :meth:`pruning_one_step`, and passes the result to :meth:`record_task_result`. It stops
    when :meth:`generate_task` returns ``None``. Subclasses implement every method except
    :meth:`compress`; the base implementations raise ``NotImplementedError``.
    """

    def generate_task(self) -> Optional[Task]:
        """
        Returns
        -------
        Optional[Task]
            Return the next pruning task, or ``None`` when there is no task left
            (``None`` ends the :meth:`compress` loop).
        """
        raise NotImplementedError()

    def record_task_result(self, task_result: TaskResult):
        """
        Parameters
        ----------
        task_result
            The result of the task.
        """
        raise NotImplementedError()

    def pruning_one_step(self, task: Task) -> TaskResult:
        """
        Pruning the model defined in task.

        Parameters
        ----------
        task
            The pruning task in this step.

        Returns
        -------
        TaskResult
            Return the result of the task in this step.
        """
        raise NotImplementedError()

    def get_best_result(self) -> Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]:
        """
        Returns
        -------
        Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]
            Return the task result that has the best performance, including the task id, the compact model,
            the masks on the compact model, the score and the config list used in this task.
        """
        raise NotImplementedError()

    def compress(self):
        """
        The pruning schedule main loop.
        """
        task = self.generate_task()
        while task is not None:
            task_result = self.pruning_one_step(task)
            self.record_task_result(task_result)
            # The result holds a model and masks. Drop it and collect garbage before
            # the next task is generated, so its memory can be reclaimed.
            del task_result
            gc.collect()
            task = self.generate_task()