nni.nas.benchmark.utils 源代码

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import functools
import json
import os

from playhouse.sqlite_ext import SqliteExtDatabase

from nni.common.blob_utils import load_or_download_file

from .constants import DB_URLS, DATABASE_DIR


json_dumps = functools.partial(json.dumps, sort_keys=True)

# to prevent repetitive loading of benchmarks
_loaded_benchmarks = {}


[文档] def load_benchmark(benchmark: str) -> SqliteExtDatabase: """ Load a benchmark as a database. Parmaeters ---------- benchmark : str Benchmark name like nasbench201. """ if benchmark in _loaded_benchmarks: return _loaded_benchmarks[benchmark] url = DB_URLS[benchmark] local_path = os.path.join(DATABASE_DIR, os.path.basename(url)) try: load_or_download_file(local_path, url) except FileNotFoundError: raise FileNotFoundError( f'Please use `nni.nas.benchmark.download_benchmark("{benchmark}")` to setup the benchmark first before using it.' ) _loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True) return _loaded_benchmarks[benchmark]
[文档] def download_benchmark(benchmark: str, progress: bool = True): """ Download a converted benchmark. Parameters ---------- benchmark : str Benchmark name like nasbench201. """ url = DB_URLS[benchmark] local_path = os.path.join(DATABASE_DIR, os.path.basename(url)) load_or_download_file(local_path, url, True, progress)