diff --git a/db_migration.py b/db_migration.py deleted file mode 100644 index 30021a57..00000000 --- a/db_migration.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Migration script to create postgresql tables -SPDX - License - Identifier: LGPL - 3.0 - or -later -Copyright © 2022 Concordia CERC group -Project Coder Peter Yefi peteryefi@gmail.com -""" - -from sqlalchemy import create_engine -from persistence.db_config import BaseConfiguration -from persistence.models import Building -from persistence.models import City - -if __name__ == '__main__': - config = BaseConfiguration(db_name='peteryefi') - engine = create_engine(config.conn_string()) - City.__table__.create(bind=engine, checkfirst=True) - diff --git a/exports/db_factory.py b/exports/db_factory.py new file mode 100644 index 00000000..1ec6861f --- /dev/null +++ b/exports/db_factory.py @@ -0,0 +1,32 @@ +""" +DBFactory performs read related operations +SPDX - License - Identifier: LGPL - 3.0 - or -later +Copyright © 2022 Concordia CERC group +Project CoderPeter Yefi peteryefi@gmail.com +""" +from persistence import CityRepo + + +class DBFactory: + """ + DBFactory class + """ + + def __init__(self, city, db_name, app_env): + self._city = city + self._city_repo = CityRepo(db_name=db_name, app_env=app_env) + + def get_city(self, city_id): + """ + Retrieve a single city from postgres + :param city_id: the id of the city to get + """ + return self._city_repo.get_by_id(city_id) + + def get_city_by_name(self, city_name): + """ + Retrieve a single city from postgres + :param city_name: the name of the city to get + """ + return self._city_repo.get_by_name(city_name) + diff --git a/imports/__init__.py b/imports/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/imports/db_factory.py b/imports/db_factory.py new file mode 100644 index 00000000..bd6ca971 --- /dev/null +++ b/imports/db_factory.py @@ -0,0 +1,39 @@ +""" +DBFactory performs database create, delete and update operations +SPDX - License - Identifier: LGPL - 3.0 - or -later +Copyright © 2022 Concordia CERC group +Project CoderPeter Yefi peteryefi@gmail.com +""" +from persistence import CityRepo + + +class DBFactory: + """ + DBFactory class + """ + + def __init__(self, city, db_name, app_env): + self._city = city + self._city_repo = CityRepo(db_name=db_name, app_env=app_env) + + def persist_city(self): + """ + Persist city into postgres database + """ + return self._city_repo.insert(self._city) + + def update_city(self, city_id, city): + """ + Update an existing city in postgres database + :param city_id: the id of the city to update + :param city: the updated city object + """ + return self._city_repo.update(city_id, city) + + def delete_city(self, city_id): + """ + Deletes a single city from postgres + :param city_id: the id of the city to get + """ + self._city_repo.delete_city(city_id) + diff --git a/persistence/README.md b/persistence/README.md index d7312f44..7517be25 100644 --- a/persistence/README.md +++ b/persistence/README.md @@ -6,7 +6,15 @@ This defines models for all class objects that we want to persist. It is used fo of the class objects to database table columns ### repositories ### -This defines repository classes that contain CRUD methods for database operations. +This defines repository classes that contain CRUD methods for database operations. The constructor of all repositories requires +The database name to connect to and the application environment (PROD or TEST). Tests use a different database +from the production environment, which is why this is necessary. An example is shown below +```python +from persistence import CityRepo +# instantiate city repo for hub production database +city_repo = CityRepo(db_name='hub', app_env='PROD') +``` +All database operations are conducted with the production database (*PROD*) named *hub* in the example above ### config_db ## This Python file is a configuration class that contains variables that map to configuration parameters in a .env file. @@ -15,14 +23,18 @@ It also contains a method ``def conn_string()`` which returns the connection str ### Base ## This class has a constructor that establishes a database connection and returns a reference for database-related CRUD operations. -### db_migration ### -This Python file is in the root of Hub and should be run to create all the required Postgres database tables - ### Database Configuration Parameter ### A .env file (or environment variables) with configuration parameters described below are needed to establish a database connection: ``` -DB_USER=postgres-database-user -DB_PASSWORD=postgres-database-password -DB_HOST=database-host -DB_PORT=database-port +# production database credentials +PROD_DB_USER=postgres-database-user +PROD_DB_PASSWORD=postgres-database-password +PROD_DB_HOST=database-host +PROD_DB_PORT=database-port + +# test database credentials +TEST_DB_USER=postgres-database-user +TEST_DB_PASSWORD=postgres-database-password +TEST_DB_HOST=database-host +TEST_DB_PORT=database-port ``` diff --git a/persistence/__init__.py b/persistence/__init__.py index cbd57dff..116dfd91 100644 --- a/persistence/__init__.py +++ b/persistence/__init__.py @@ -1,4 +1,3 @@ from .base import BaseRepo from .repositories.city_repo import CityRepo from .repositories.building_repo import BuildingRepo -from .repositories.city_repo import CityRepo \ No newline at end of file diff --git a/persistence/base.py b/persistence/base.py index 89478772..08e1103a 100644 --- a/persistence/base.py +++ b/persistence/base.py @@ -11,8 +11,20 @@ from sqlalchemy.orm import Session class BaseRepo: - def __init__(self, db_name): - config = BaseConfiguration(db_name) + + def __init__(self, db_name, app_env='TEST'): + config = BaseConfiguration(db_name, app_env) engine = create_engine(config.conn_string()) self.config = config self.session = Session(engine) + + def __del__(self): + """ + Close database sessions + :return: + """ + self.session.close() + + + + diff --git a/persistence/db_config.py b/persistence/db_config.py index ab8afc07..a7f378e5 100644 --- a/persistence/db_config.py +++ b/persistence/db_config.py @@ -20,13 +20,17 @@ class BaseConfiguration(object): """ Base configuration class to hold common persistence configuration """ - def __init__(self, db_name: str): + def __init__(self, db_name: str, app_env='TEST'): + """ + :param db_name: database name + :param app_env: application environment, test or production + """ self._db_name = db_name - self._db_host = os.getenv('DB_HOST') - self._db_user = os.getenv('DB_USER') - self._db_pass = os.getenv('DB_PASSWORD') - self._db_port = os.getenv('DB_PORT') - self.hub_token = os.getenv('HUB_TOKEN') + self._db_host = os.getenv(f'{app_env}_DB_HOST') + self._db_user = os.getenv(f'{app_env}_DB_USER') + self._db_pass = os.getenv(f'{app_env}_DB_PASSWORD') + self._db_port = os.getenv(f'{app_env}_DB_PORT') + self.hub_token = os.getenv(f'{app_env}_HUB_TOKEN') def conn_string(self): """ diff --git a/persistence/repositories/city_repo.py b/persistence/repositories/city_repo.py index 2e474b16..4e3cb47e 100644 --- a/persistence/repositories/city_repo.py +++ b/persistence/repositories/city_repo.py @@ -10,7 +10,7 @@ from persistence import BaseRepo from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import select from helpers.city_util import CityUtil -from persistence.models import City +from persistence.models import City as DBCity import pickle import requests from urllib3.exceptions import HTTPError @@ -18,22 +18,32 @@ from typing import Union, Dict class CityRepo(BaseRepo): - def __init__(self, db_name): - super().__init__(db_name) + _instance = None + + def __init__(self, db_name, app_env): + super().__init__(db_name, app_env) self._city_util = CityUtil() + def __new__(cls, db_name, app_env): + """ + Implemented for a singleton pattern + """ + if cls._instance is None: + cls._instance = super(CityRepo, cls).__new__(cls) + return cls._instance + def insert(self, city: City) -> Union[City, Dict]: - model_city = City() - model_city.name = city.name - model_city.climate_reference_city = city.climate_reference_city - model_city.srs_name = city.srs_name - model_city.longitude = city.longitude - model_city.latitude = city.latitude - model_city.country_code = city.country_code - model_city.time_zone = city.time_zone - model_city.lower_corner = city.lower_corner.tolist() - model_city.upper_corner = city.upper_corner.tolist() - model_city.city = pickle.dumps(city) + db_city = DBCity() + db_city.name = city.name + db_city.climate_reference_city = city.climate_reference_city + db_city.srs_name = city.srs_name + db_city.longitude = city.longitude + db_city.latitude = city.latitude + db_city.country_code = city.country_code + db_city.time_zone = city.time_zone + db_city.lower_corner = city.lower_corner.tolist() + db_city.upper_corner = city.upper_corner.tolist() + db_city.city = pickle.dumps(city) try: # Retrieve hub project latest release @@ -44,19 +54,20 @@ class CityRepo(BaseRepo): # Do not persist the same city for the same version of Hub if exiting_city is None: - model_city.hub_release = recent_commit + db_city.hub_release = recent_commit cities = self.get_by_name(city.name) # update version for the same city but different hub versions + if len(cities) == 0: - model_city.city_version = 0 + db_city.city_version = 0 else: - model_city.city_version = cities[-1].city_version + 1 + db_city.city_version = cities[-1].city_version + 1 # Persist city - self.session.add(model_city) + self.session.add(db_city) self.session.flush() self.session.commit() - return model_city + return db_city else: return {'message': f'Same version of {city.name} exist'} except SQLAlchemyError as err: @@ -64,14 +75,14 @@ class CityRepo(BaseRepo): except HTTPError as err: print(f'Error retrieving Hub latest release: {err}') - def get_by_id(self, city_id: int) -> City: + def get_by_id(self, city_id: int) -> DBCity: """ Fetch a City based on the id :param city_id: the city id :return: a city """ try: - return self.session.execute(select(City).where(City.id == city_id)).first()[0] + return self.session.execute(select(DBCity).where(DBCity.id == city_id)).first()[0] except SQLAlchemyError as err: print(f'Error while fetching city: {err}') @@ -83,8 +94,8 @@ class CityRepo(BaseRepo): :return: a city """ try: - return self.session.execute(select(City) - .where(City.hub_release == hub_commit, City.name == city_name)).first() + return self.session.execute(select(DBCity) + .where(DBCity.hub_release == hub_commit, DBCity.name == city_name)).first() except SQLAlchemyError as err: print(f'Error while fetching city: {err}') @@ -96,7 +107,7 @@ class CityRepo(BaseRepo): :return: """ try: - self.session.query(City).filter(City.id == city_id) \ + self.session.query(DBCity).filter(DBCity.id == city_id) \ .update({ 'name': city.name, 'srs_name': city.srs_name, 'country_code': city.country_code, 'longitude': city.longitude, 'latitude': city.latitude, 'time_zone': city.time_zone, 'lower_corner': city.lower_corner.tolist(), @@ -107,14 +118,27 @@ class CityRepo(BaseRepo): except SQLAlchemyError as err: print(f'Error while updating city: {err}') - def get_by_name(self, city_name: str) -> [City]: + def get_by_name(self, city_name: str) -> [DBCity]: """ Fetch city based on the name :param city_name: the name of the building :return: [ModelCity] with the provided name """ try: - result_set = self.session.execute(select(City).where(City.name == city_name)) + result_set = self.session.execute(select(DBCity).where(DBCity.name == city_name)) return [building[0] for building in result_set] except SQLAlchemyError as err: print(f'Error while fetching city by name: {err}') + + def delete_city(self, city_id: int): + """ + Deletes a City with the id + :param city_id: the city id + :return: a city + """ + try: + self.session.query(DBCity).filter(DBCity.id == city_id).delete() + self.session.commit() + except SQLAlchemyError as err: + print(f'Error while fetching city: {err}') + diff --git a/unittests/test_db_factory.py b/unittests/test_db_factory.py new file mode 100644 index 00000000..cdfe2ae1 --- /dev/null +++ b/unittests/test_db_factory.py @@ -0,0 +1,76 @@ +""" +Test EnergySystemsFactory and various heatpump models +SPDX - License - Identifier: LGPL - 3.0 - or -later +Copyright © 2022 Concordia CERC group +Project Coder Peter Yefi peteryefi@gmail.com +""" +from unittest import TestCase +from imports.geometry_factory import GeometryFactory +from imports.db_factory import DBFactory +from exports.db_factory import DBFactory as ExportDBFactory +from persistence.db_config import BaseConfiguration +from sqlalchemy import create_engine +from persistence.models import City +from pickle import loads + + +class TestDBFactory(TestCase): + """ + TestDBFactory + """ + + def setUp(self) -> None: + """ + Test setup + :return: None + """ + # Create test tables if they do not exit + config = BaseConfiguration(db_name='hub_test', app_env='TEST') + engine = create_engine(config.conn_string()) + City.__table__.create(bind=engine, checkfirst=True) + + city_file = "../unittests/tests_data/C40_Final.gml" + self.city = GeometryFactory('citygml', city_file).city + self._db_factory = DBFactory(city=self.city, db_name='hub_test', app_env='TEST') + self._export_db_factory = ExportDBFactory(city=self.city, db_name='hub_test', app_env='TEST') + + def test_save_city(self): + self._saved_city = self._db_factory.persist_city() + self.assertEqual(self._saved_city.name, 'Montréal') + pickled_city = loads(self._saved_city.city) + self.assertEqual(len(pickled_city.buildings), 10) + self.assertEqual(pickled_city.buildings[0].floor_area, 1990.9913970530033) + self._db_factory.delete_city(self._saved_city.id) + + def test_save_same_city_with_same_hub_version(self): + first_city = self._db_factory.persist_city() + second_city = self._db_factory.persist_city() + self.assertEqual(second_city['message'], f'Same version of {self.city.name} exist') + self.assertEqual(first_city.name, 'Montréal') + self.assertEqual(first_city.country_code, 'ca') + self._db_factory.delete_city(first_city.id) + + def test_get_city_by_name(self): + city = self._db_factory.persist_city() + retrieved_city = self._export_db_factory.get_city_by_name(city.name) + self.assertEqual(retrieved_city[0].lower_corner[0], 610610.7547462888) + self._db_factory.delete_city(city.id) + + def test_get_city_by_id(self): + city = self._db_factory.persist_city() + retrieved_city = self._export_db_factory.get_city(city.id) + self.assertEqual(retrieved_city.upper_corner[0], 610818.6731258357) + self._db_factory.delete_city(city.id) + + def test_get_update_city(self): + city = self._db_factory.persist_city() + self.city.longitude = 1.43589 + self.city.latitude = -9.38928339 + self._db_factory.update_city(city.id, self.city) + updated_city = self._export_db_factory.get_city(city.id) + self.assertEqual(updated_city.longitude, 1.43589) + self.assertEqual(updated_city.latitude, -9.38928339) + self._db_factory.delete_city(city.id) + + +