hub/venv/lib/python3.7/site-packages/zmq/tests/test_auth.py

558 lines
20 KiB
Python
Raw Normal View History

# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import functools
import logging
import os
import shutil
import sys
import tempfile

import pytest

import zmq.auth
from zmq.auth.thread import ThreadAuthenticator
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
from zmq.utils.strtypes import u
class BaseAuthTestCase(BaseZMQTestCase):
    """Common fixture for the authenticator tests.

    ``setUp`` skips the test when libzmq is older than 4.0 or lacks CURVE
    support, starts the authenticator returned by :meth:`make_auth`, and
    generates a throw-away set of CURVE certificates in a temporary
    directory.  ``tearDown`` stops the authenticator and removes that
    directory.  Subclasses must implement :meth:`make_auth`.
    """

    def setUp(self):
        """Check libzmq capabilities, start the authenticator, create certs."""
        if zmq.zmq_version_info() < (4, 0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to have curve support")

        super(BaseAuthTestCase, self).setUp()
        # enable debug logging while we run tests
        logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
        self.auth = self.make_auth()
        self.auth.start()
        self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()

    def make_auth(self):
        """Return the authenticator under test; subclasses must override."""
        raise NotImplementedError()

    def tearDown(self):
        """Stop the authenticator, remove the certificates, run parent teardown.

        Each step runs even if the previous one raised, so a failing
        ``auth.stop()`` can no longer leave the temporary certificate
        directory behind or skip the parent class teardown.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            try:
                self.remove_certs(self.base_dir)
            finally:
                super(BaseAuthTestCase, self).tearDown()

    def create_certs(self):
        """Create CURVE certificates for a test.

        A "server" and a "client" certificate are generated in a scratch
        ``certificates`` directory and then sorted by file suffix:
        ``*.key`` files go to ``public_keys`` and ``*.key_secret`` files go
        to ``private_keys``.

        Returns:
            A ``(base_dir, public_keys_dir, secret_keys_dir)`` tuple of paths.

        Raises:
            Whatever directory creation, certificate generation or the file
            moves raise; the temporary directory is removed first.
        """
        base_dir = tempfile.mkdtemp()
        try:
            keys_dir = os.path.join(base_dir, 'certificates')
            public_keys_dir = os.path.join(base_dir, 'public_keys')
            secret_keys_dir = os.path.join(base_dir, 'private_keys')
            for directory in (keys_dir, public_keys_dir, secret_keys_dir):
                os.mkdir(directory)

            for name in ("server", "client"):
                zmq.auth.create_certificates(keys_dir, name)

            # os.listdir returns a list, so moving files out of keys_dir
            # while looping over it is safe.
            for key_file in os.listdir(keys_dir):
                if key_file.endswith(".key"):
                    target_dir = public_keys_dir
                elif key_file.endswith(".key_secret"):
                    target_dir = secret_keys_dir
                else:
                    continue
                shutil.move(os.path.join(keys_dir, key_file),
                            os.path.join(target_dir, key_file))
        except Exception:
            # Do not leak the temporary directory when setup fails part-way.
            shutil.rmtree(base_dir, ignore_errors=True)
            raise
        return (base_dir, public_keys_dir, secret_keys_dir)

    def remove_certs(self, base_dir):
        """Remove certificates for a test by deleting ``base_dir`` recursively."""
        shutil.rmtree(base_dir)

    def load_certs(self, secret_keys_dir):
        """Return server and client certificate keys.

        Reads ``server.key_secret`` and ``client.key_secret`` from
        ``secret_keys_dir``.

        Returns:
            ``(server_public, server_secret, client_public, client_secret)``.
        """
        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")

        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)

        return server_public, server_secret, client_public, client_secret
class TestThreadAuthentication(BaseAuthTestCase):
    """Test authentication running in a thread"""

    def make_auth(self):
        """Return a ThreadAuthenticator bound to the test context."""
        return ThreadAuthenticator(self.context)

    # -- helpers -----------------------------------------------------------

    def can_connect(self, server, client):
        """Check if client can connect to server using tcp transport.

        ``server`` is bound to a random port on 127.0.0.1, ``client``
        connects to it, and one message is sent from ``server`` to
        ``client``.  Each poll waits up to 1000 ms.

        Returns:
            True if the message arrived (and matched), otherwise False.
        """
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = [b"Hello World"]
        if not server.poll(1000, zmq.POLLOUT):
            return False
        server.send_multipart(msg)
        if not client.poll(1000):
            return False
        rcvd_msg = client.recv_multipart()
        self.assertEqual(rcvd_msg, msg)
        return True

    def _null_sockets(self, zap_domain=None):
        """Return a (PUSH server, PULL client) pair without PLAIN/CURVE options.

        When ``zap_domain`` is given it is set on the server socket.
        """
        server = self.socket(zmq.PUSH)
        if zap_domain is not None:
            server.zap_domain = zap_domain
        client = self.socket(zmq.PULL)
        return server, client

    def _plain_sockets(self, password):
        """Return a PLAIN (PUSH server, PULL client) pair.

        The client logs in as ``admin`` with the given ``password`` (bytes).
        """
        server = self.socket(zmq.PUSH)
        server.plain_server = True
        client = self.socket(zmq.PULL)
        client.plain_username = b'admin'
        client.plain_password = password
        return server, client

    def _curve_server(self, certs, socket_type=zmq.PUSH):
        """Return a CURVE server socket of ``socket_type`` using ``certs``."""
        server_public, server_secret, _, _ = certs
        server = self.socket(socket_type)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        return server

    def _curve_client(self, certs, socket_type=zmq.PULL):
        """Return a CURVE client socket of ``socket_type`` using ``certs``."""
        server_public, _, client_public, client_secret = certs
        client = self.socket(socket_type)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        return client

    def _curve_sockets(self, certs, server_type=zmq.PUSH, client_type=zmq.PULL):
        """Return a CURVE (server, client) pair; the server is created first."""
        server = self._curve_server(certs, server_type)
        client = self._curve_client(certs, client_type)
        return server, client

    def _assert_user_id(self, msg, expected):
        """Assert that the 'User-Id' metadata of ``msg`` equals ``expected``.

        If ``msg.get`` raises ``zmq.ZMQVersionError`` the check is skipped.
        """
        try:
            user_id = msg.get('User-Id')
        except zmq.ZMQVersionError:
            return
        assert user_id == expected

    # -- tests -------------------------------------------------------------

    def test_null(self):
        """threaded auth - NULL"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        self.auth.stop()
        self.auth = None
        # use a new context, so ZAP isn't inherited
        self.context = self.Context()

        server, client = self._null_sockets()
        self.assertTrue(self.can_connect(server, client))

        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        server, client = self._null_sockets(zap_domain=b'global')
        self.assertTrue(self.can_connect(server, client))

    def test_blacklist(self):
        """threaded auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet.
        server, client = self._null_sockets(zap_domain=b'global')
        self.assertFalse(self.can_connect(server, client))

    def test_whitelist(self):
        """threaded auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet.
        server, client = self._null_sockets(zap_domain=b'global')
        self.assertTrue(self.can_connect(server, client))

    def test_plain(self):
        """threaded auth - PLAIN"""
        # Try PLAIN authentication - without configuring server, connection should fail
        server, client = self._plain_sockets(b'Password')
        self.assertFalse(self.can_connect(server, client))

        # Try PLAIN authentication - with server configured, connection should pass
        server, client = self._plain_sockets(b'Password')
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
        self.assertTrue(self.can_connect(server, client))

        # Try PLAIN authentication - with bogus credentials, connection should fail
        server, client = self._plain_sockets(b'Bogus')
        self.assertFalse(self.can_connect(server, client))

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        server, client = self._null_sockets()
        self.assertTrue(self.can_connect(server, client))
        client.close()
        server.close()

    def test_curve(self):
        """threaded auth - CURVE"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)

        # Try CURVE authentication - without configuring server, connection should fail
        server, client = self._curve_sockets(certs)
        self.assertFalse(self.can_connect(server, client))

        # Try CURVE authentication - with server configured to CURVE_ALLOW_ANY,
        # connection should pass
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        server, client = self._curve_sockets(certs)
        self.assertTrue(self.can_connect(server, client))

        # Try CURVE authentication - with server configured, connection should pass.
        # Here the socket types are swapped and the CURVE client is the one
        # that binds.
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        server, client = self._curve_sockets(
            certs, server_type=zmq.PULL, client_type=zmq.PUSH)
        assert self.can_connect(client, server)

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        # Try connecting using NULL and no authentication enabled, connection should pass
        server, client = self._null_sockets()
        self.assertTrue(self.can_connect(server, client))

    def test_curve_callback(self):
        """threaded auth - CURVE with callback authentication"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)
        client_public = certs[2]

        # Try CURVE authentication - without configuring server, connection should fail
        server, client = self._curve_sockets(certs)
        self.assertFalse(self.can_connect(server, client))

        class CredentialsProvider(object):
            """Credentials provider that accepts exactly one client key."""

            def __init__(self, accepted_key):
                self.client = accepted_key

            def callback(self, domain, key):
                return key == self.client

        # Try CURVE authentication - with callback authentication configured,
        # connection should pass
        provider = CredentialsProvider(client_public)
        self.auth.configure_curve_callback(credentials_provider=provider)
        server, client = self._curve_sockets(certs)
        self.assertTrue(self.can_connect(server, client))

        # Try CURVE authentication - with callback authentication configured
        # with wrong key, connection should not pass
        provider = CredentialsProvider("WrongCredentials")
        self.auth.configure_curve_callback(credentials_provider=provider)
        server, client = self._curve_sockets(certs)
        self.assertFalse(self.can_connect(server, client))

    @skip_pypy
    def test_curve_user_id(self):
        """threaded auth - CURVE user-id"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)
        client_public = certs[2]

        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        server, client = self._curve_sockets(
            certs, server_type=zmq.PULL, client_type=zmq.PUSH)
        assert self.can_connect(client, server)

        # test default user-id map
        client.send(b'test')
        msg = self.recv(server, copy=False)
        assert msg.bytes == b'test'
        self._assert_user_id(msg, u(client_public))

        # test custom user-id map
        self.auth.curve_user_id = lambda client_key: u'custom'

        client2 = self._curve_client(certs, zmq.PUSH)
        assert self.can_connect(client2, server)

        client2.send(b'test2')
        msg = self.recv(server, copy=False)
        assert msg.bytes == b'test2'
        self._assert_user_id(msg, u'custom')
def with_ioloop(method, expect_success=True):
    """Decorator for running tests with an IOLoop.

    The returned wrapper first calls ``method(self)`` (which configures the
    sockets/authenticator), then schedules on ``self.io_loop``:

    * ``self.attempt_connection`` after 1 second,
    * ``self.send_msg`` after 1.2 seconds,
    * a timeout handler after 2 seconds,

    and starts the loop.  With ``expect_success`` true, a received message
    is routed to ``self.on_message_succeed`` and the timeout to
    ``self.on_test_timeout_fail``; otherwise the failing/succeeding handlers
    are swapped.  Once the loop returns, a non-empty ``self.fail_msg`` is
    reported through ``self.fail``.

    ``functools.wraps`` keeps the wrapped test's name and docstring, so
    reports identify the real test instead of a generic ``test_method``.

    Args:
        method: the test method to wrap; called with ``self`` only.
        expect_success: whether the message is expected to be delivered.

    Returns:
        The wrapper, which returns whatever ``method`` returned.
    """

    @functools.wraps(method)
    def test_method(self):
        r = method(self)

        loop = self.io_loop
        if expect_success:
            on_message = self.on_message_succeed
            on_timeout = self.on_test_timeout_fail
        else:
            on_message = self.on_message_fail
            on_timeout = self.on_test_timeout_succeed

        self.pullstream.on_recv(on_message)
        loop.call_later(1, self.attempt_connection)
        loop.call_later(1.2, self.send_msg)
        loop.call_later(2, on_timeout)

        loop.start()
        if self.fail_msg:
            self.fail(self.fail_msg)

        return r

    return test_method
def should_auth(method):
    """Decorate an IOLoop test whose message is expected to be delivered."""
    decorated = with_ioloop(method, expect_success=True)
    return decorated
def should_not_auth(method):
    """Decorate an IOLoop test whose message is expected NOT to be delivered."""
    decorated = with_ioloop(method, expect_success=False)
    return decorated
class TestIOLoopAuthentication(BaseAuthTestCase):
    """Test authentication running in ioloop

    Each ``test_*`` method only configures the sockets and the
    authenticator; the ``should_auth`` / ``should_not_auth`` decorators then
    drive the connection attempt and message exchange on ``self.io_loop``.
    """

    def setUp(self):
        """Create the IOLoop, run the base setup, then build sockets and streams."""
        try:
            from tornado import ioloop
        except ImportError:
            pytest.skip("Requires tornado")
        from zmq.eventloop import zmqstream

        self.fail_msg = None
        # The loop must exist before the base setUp runs, because that
        # calls make_auth(), which passes self.io_loop to the authenticator.
        self.io_loop = ioloop.IOLoop()
        super(TestIOLoopAuthentication, self).setUp()
        self.server = self.socket(zmq.PUSH)
        self.client = self.socket(zmq.PULL)
        self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
        self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)

    def make_auth(self):
        """Return an IOLoopAuthenticator attached to ``self.io_loop``."""
        from zmq.auth.ioloop import IOLoopAuthenticator
        return IOLoopAuthenticator(self.context, io_loop=self.io_loop)

    def tearDown(self):
        """Stop the authenticator, close the loop, run the base teardown.

        Each step runs even if the previous one raised, so a failing
        ``auth.stop()`` does not leave the loop open or skip the base
        class cleanup.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            try:
                self.io_loop.close(all_fds=True)
            finally:
                super(TestIOLoopAuthentication, self).tearDown()

    # -- helpers and loop callbacks ---------------------------------------

    def _configure_curve_sockets(self):
        """Load the test certificates and set CURVE options on both sockets."""
        certs = self.load_certs(self.secret_keys_dir)
        server_public, server_secret, client_public, client_secret = certs
        self.server.curve_publickey = server_public
        self.server.curve_secretkey = server_secret
        self.server.curve_server = True
        self.client.curve_publickey = client_public
        self.client.curve_secretkey = client_secret
        self.client.curve_serverkey = server_public

    def attempt_connection(self):
        """Check if client can connect to server using tcp transport"""
        iface = 'tcp://127.0.0.1'
        port = self.server.bind_to_random_port(iface)
        self.client.connect("%s:%i" % (iface, port))

    def send_msg(self):
        """Send a message from server to a client"""
        msg = [b"Hello World"]
        self.pushstream.send_multipart(msg)

    def on_message_succeed(self, frames):
        """A message was received, as expected."""
        if frames != [b"Hello World"]:
            self.fail_msg = "Unexpected message received"
        self.io_loop.stop()

    def on_message_fail(self, frames):
        """A message was received, unexpectedly."""
        self.fail_msg = 'Received messaged unexpectedly, security failed'
        self.io_loop.stop()

    def on_test_timeout_succeed(self):
        """Test timer expired, indicates test success"""
        self.io_loop.stop()

    def on_test_timeout_fail(self):
        """Test timer expired, indicates test failure"""
        self.fail_msg = 'Test timed out'
        self.io_loop.stop()

    # -- tests -------------------------------------------------------------

    @should_auth
    def test_none(self):
        """ioloop auth - NONE"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        # no auth should be running
        self.auth.stop()
        self.auth = None

    @should_auth
    def test_null(self):
        """ioloop auth - NULL"""
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        self.server.zap_domain = b'global'

    @should_not_auth
    def test_blacklist(self):
        """ioloop auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        self.server.zap_domain = b'global'

    @should_auth
    def test_whitelist(self):
        """ioloop auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')
        self.server.zap_domain = b'global'

    @should_not_auth
    def test_plain_unconfigured_server(self):
        """ioloop auth - PLAIN, unconfigured server"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Password'
        # Try PLAIN authentication - without configuring server, connection should fail
        self.server.plain_server = True

    @should_auth
    def test_plain_configured_server(self):
        """ioloop auth - PLAIN, configured server"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Password'
        # Try PLAIN authentication - with server configured, connection should pass
        self.server.plain_server = True
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_plain_bogus_credentials(self):
        """ioloop auth - PLAIN, bogus credentials"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Bogus'
        self.server.plain_server = True

        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_curve_unconfigured_server(self):
        """ioloop auth - CURVE, unconfigured server"""
        # No configure_curve() call here, so the connection should fail.
        self.auth.allow('127.0.0.1')
        self._configure_curve_sockets()

    @should_auth
    def test_curve_allow_any(self):
        """ioloop auth - CURVE, CURVE_ALLOW_ANY"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        self._configure_curve_sockets()

    @should_auth
    def test_curve_configured_server(self):
        """ioloop auth - CURVE, configured server"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        self._configure_curve_sockets()