diff --git a/requirements.txt b/requirements.txt index ec5412d..f894311 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ python-dotenv SQLAlchemy cerc-hub psycopg2-binary +sqlalchemy_utils \ No newline at end of file diff --git a/tests/test_db_factory.py b/tests/test_db_factory.py index 7eff3d2..6335d21 100644 --- a/tests/test_db_factory.py +++ b/tests/test_db_factory.py @@ -14,9 +14,12 @@ import unittest from pathlib import Path from unittest import TestCase +import psycopg2 import sqlalchemy.exc from sqlalchemy import create_engine from sqlalchemy.exc import ProgrammingError +from sqlalchemy_utils import database_exists, create_database, drop_database + import hub.helpers.constants as cte from hub.exports.energy_building_exports_factory import EnergyBuildingsExportsFactory @@ -54,22 +57,14 @@ class Control: dotenv_path = str(dotenv_path) repository = Repository(db_name='test_db', app_env='TEST', dotenv_path=dotenv_path) engine = create_engine(repository.configuration.connection_string) - try: - # delete test database if it exists - connection = engine.connect() - connection.close() - except ProgrammingError: - logging.info('Database does not exist. Nothing to delete') - except sqlalchemy.exc.OperationalError as operational_error: - self._skip_test = True - self._skip_reason = f'{operational_error}' - return - - Application.__table__.create(bind=repository.engine, checkfirst=True) - User.__table__.create(bind=repository.engine, checkfirst=True) - City.__table__.create(bind=repository.engine, checkfirst=True) - CityObject.__table__.create(bind=repository.engine, checkfirst=True) - SimulationResults.__table__.create(bind=repository.engine, checkfirst=True) + if database_exists(engine.url): + drop_database(engine.url) + create_database(engine.url) + Application.__table__.create(bind=engine, checkfirst=True) + User.__table__.create(bind=engine, checkfirst=True) + City.__table__.create(bind=engine, checkfirst=True) + CityObject.__table__.create(bind=engine, checkfirst=True) + SimulationResults.__table__.create(bind=engine, checkfirst=True) city_file = Path('tests_data/test.geojson').resolve() output_path = Path('tests_outputs/').resolve()