95 lines
2.8 KiB
Python
95 lines
2.8 KiB
Python
"""A shim module for deprecated imports
|
|
"""
|
|
# Copyright (c) IPython Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import sys
|
|
import types
|
|
from importlib import import_module
|
|
|
|
from .importstring import import_item
|
|
|
|
|
|
class ShimWarning(Warning):
|
|
"""A warning to show when a module has moved, and a shim is in its place."""
|
|
|
|
class ShimImporter(object):
|
|
"""Import hook for a shim.
|
|
|
|
This ensures that submodule imports return the real target module,
|
|
not a clone that will confuse `is` and `isinstance` checks.
|
|
"""
|
|
def __init__(self, src, mirror):
|
|
self.src = src
|
|
self.mirror = mirror
|
|
|
|
def _mirror_name(self, fullname):
|
|
"""get the name of the mirrored module"""
|
|
|
|
return self.mirror + fullname[len(self.src):]
|
|
|
|
def find_module(self, fullname, path=None):
|
|
"""Return self if we should be used to import the module."""
|
|
if fullname.startswith(self.src + '.'):
|
|
mirror_name = self._mirror_name(fullname)
|
|
try:
|
|
mod = import_item(mirror_name)
|
|
except ImportError:
|
|
return
|
|
else:
|
|
if not isinstance(mod, types.ModuleType):
|
|
# not a module
|
|
return None
|
|
return self
|
|
|
|
def load_module(self, fullname):
|
|
"""Import the mirrored module, and insert it into sys.modules"""
|
|
mirror_name = self._mirror_name(fullname)
|
|
mod = import_item(mirror_name)
|
|
sys.modules[fullname] = mod
|
|
return mod
|
|
|
|
|
|
class ShimModule(types.ModuleType):
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
self._mirror = kwargs.pop("mirror")
|
|
src = kwargs.pop("src", None)
|
|
if src:
|
|
kwargs['name'] = src.rsplit('.', 1)[-1]
|
|
super(ShimModule, self).__init__(*args, **kwargs)
|
|
# add import hook for descendent modules
|
|
if src:
|
|
sys.meta_path.append(
|
|
ShimImporter(src=src, mirror=self._mirror)
|
|
)
|
|
|
|
@property
|
|
def __path__(self):
|
|
return []
|
|
|
|
@property
|
|
def __spec__(self):
|
|
"""Don't produce __spec__ until requested"""
|
|
return import_module(self._mirror).__spec__
|
|
|
|
def __dir__(self):
|
|
return dir(import_module(self._mirror))
|
|
|
|
@property
|
|
def __all__(self):
|
|
"""Ensure __all__ is always defined"""
|
|
mod = import_module(self._mirror)
|
|
try:
|
|
return mod.__all__
|
|
except AttributeError:
|
|
return [name for name in dir(mod) if not name.startswith('_')]
|
|
|
|
def __getattr__(self, key):
|
|
# Use the equivalent of import_item(name), see below
|
|
name = "%s.%s" % (self._mirror, key)
|
|
try:
|
|
return import_item(name)
|
|
except ImportError:
|
|
raise AttributeError(key)
|