hub/venv/lib/python3.7/site-packages/scipy/fft/_backend.py

156 lines
4.7 KiB
Python

import scipy._lib.uarray as ua
from . import _pocketfft
class _ScipyBackend:
    """Default uarray backend for the fft functions, backed by ``_pocketfft``.

    Notes
    -----
    The domain is ``numpy.scipy.fft`` rather than plain ``scipy.fft``
    because ``uarray`` is expected to treat domains as a hierarchy in the
    future. A user could then install one backend for ``numpy`` and have
    it serve ``numpy.scipy.fft`` too.
    """

    __ua_domain__ = "numpy.scipy.fft"

    @staticmethod
    def __ua_function__(method, args, kwargs):
        # Dispatch to the same-named function in _pocketfft, if one exists.
        impl = getattr(_pocketfft, method.__name__, None)
        if impl is not None:
            return impl(*args, **kwargs)
        # No matching implementation: signal uarray that this backend
        # cannot handle the call.
        return NotImplemented
# Backends that callers may select by name: a ``str`` passed to
# _backend_from_arg is looked up here and replaced by the mapped object.
_named_backends = dict(scipy=_ScipyBackend)
def _backend_from_arg(backend):
    """Map a backend name to a backend object and validate its domain.

    Parameters
    ----------
    backend : {object, str}
        Either the name of a known backend (a key of ``_named_backends``,
        e.g. ``'scipy'``) or an object implementing the uarray protocol.

    Returns
    -------
    object
        The resolved backend, whose ``__ua_domain__`` is
        ``'numpy.scipy.fft'``.

    Raises
    ------
    ValueError
        If ``backend`` is a string that names no known backend, or if the
        resolved backend does not declare the ``'numpy.scipy.fft'`` domain
        (including when it has no ``__ua_domain__`` attribute at all).
    """
    if isinstance(backend, str):
        try:
            backend = _named_backends[backend]
        except KeyError as e:
            # Chain the failed lookup as the explicit cause so the traceback
            # does not read as a second, unrelated error "during handling".
            raise ValueError('Unknown backend {}'.format(backend)) from e

    # Use getattr with a default so that an object lacking __ua_domain__
    # (i.e. one that does not implement the uarray protocol) gets the
    # documented ValueError rather than an AttributeError.
    if getattr(backend, '__ua_domain__', None) != 'numpy.scipy.fft':
        raise ValueError('Backend does not implement "numpy.scipy.fft"')
    return backend
def set_global_backend(backend):
    """Set the global fft backend.

    The global backend is tried before any backend added with
    `register_backend`, but after any backend installed for a scope with
    `set_backend`.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either a ``str`` naming a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.

    Notes
    -----
    This replaces whatever global backend was set before; by default that
    is the SciPy implementation.
    """
    ua.set_global_backend(_backend_from_arg(backend))
def register_backend(backend):
    """Register a backend for permanent use.

    Registered backends have the lowest priority: they are tried only after
    the global backend.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either a ``str`` naming a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.
    """
    resolved = _backend_from_arg(backend)
    ua.register_backend(resolved)
def set_backend(backend, coerce=False, only=False):
    """Context manager that sets the backend within a fixed scope.

    On entering the ``with`` statement, the given backend is added to the
    available backends with the highest priority. On exit, the backends
    are restored to their state from before the scope was entered.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either a ``str`` naming a known backend
        ({'scipy'}) or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to allow expensive conversions of the ``x`` parameter, e.g.
        copying a NumPy array to the GPU for a CuPy backend. Implies
        ``only``.
    only : bool, optional
        If ``True`` and this backend returns ``NotImplemented``, a
        ``BackendNotImplementedError`` is raised immediately and any
        lower-priority backends are ignored.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> with fft.set_backend('scipy', only=True):
    ...     fft.fft([1])  # Always calls the scipy implementation
    array([1.+0.j])
    """
    chosen = _backend_from_arg(backend)
    return ua.set_backend(chosen, coerce=coerce, only=only)
def skip_backend(backend):
    """Context manager that skips a backend within a fixed scope.

    Inside the ``with`` statement, the given backend is not called. This
    applies to backends registered both locally and globally. On exit, the
    backend is considered again.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to skip. Either a ``str`` naming a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> fft.fft([1])  # Calls default scipy backend
    array([1.+0.j])
    >>> with fft.skip_backend('scipy'):  # We explicitly skip the scipy backend
    ...     fft.fft([1])                 # leaving no implementation available
    Traceback (most recent call last):
        ...
    BackendNotImplementedError: No selected backends had an implementation ...
    """
    return ua.skip_backend(_backend_from_arg(backend))
# Runs at import: install the built-in 'scipy' backend as the global
# default so the fft functions dispatch without any user configuration.
set_global_backend(backend='scipy')