Included tests, import and export for db operations

This commit is contained in:
Peter Yefi 2022-11-15 20:48:42 -05:00
parent 1881c1db78
commit 43d7d1ee77
10 changed files with 241 additions and 60 deletions

View File

@ -1,17 +0,0 @@
"""
Migration script to create postgresql tables
SPDX - License - Identifier: LGPL - 3.0 - or -later
Copyright © 2022 Concordia CERC group
Project Coder Peter Yefi peteryefi@gmail.com
"""
from sqlalchemy import create_engine
from persistence.db_config import BaseConfiguration
from persistence.models import Building
from persistence.models import City
if __name__ == '__main__':
config = BaseConfiguration(db_name='peteryefi')
engine = create_engine(config.conn_string())
City.__table__.create(bind=engine, checkfirst=True)

32
exports/db_factory.py Normal file
View File

@ -0,0 +1,32 @@
"""
DBFactory performs read related operations
SPDX - License - Identifier: LGPL - 3.0 - or -later
Copyright © 2022 Concordia CERC group
Project CoderPeter Yefi peteryefi@gmail.com
"""
from persistence import CityRepo
class DBFactory:
"""
DBFactory class
"""
def __init__(self, city, db_name, app_env):
self._city = city
self._city_repo = CityRepo(db_name=db_name, app_env=app_env)
def get_city(self, city_id):
"""
Retrieve a single city from postgres
:param city_id: the id of the city to get
"""
return self._city_repo.get_by_id(city_id)
def get_city_by_name(self, city_name):
"""
Retrieve a single city from postgres
:param city_name: the name of the city to get
"""
return self._city_repo.get_by_name(city_name)

0
imports/__init__.py Normal file
View File

39
imports/db_factory.py Normal file
View File

@ -0,0 +1,39 @@
"""
DBFactory performs database create, delete and update operations
SPDX - License - Identifier: LGPL - 3.0 - or -later
Copyright © 2022 Concordia CERC group
Project CoderPeter Yefi peteryefi@gmail.com
"""
from persistence import CityRepo
class DBFactory:
"""
DBFactory class
"""
def __init__(self, city, db_name, app_env):
self._city = city
self._city_repo = CityRepo(db_name=db_name, app_env=app_env)
def persist_city(self):
"""
Persist city into postgres database
"""
return self._city_repo.insert(self._city)
def update_city(self, city_id, city):
"""
Update an existing city in postgres database
:param city_id: the id of the city to update
:param city: the updated city object
"""
return self._city_repo.update(city_id, city)
def delete_city(self, city_id):
"""
Deletes a single city from postgres
:param city_id: the id of the city to get
"""
self._city_repo.delete_city(city_id)

View File

@ -6,7 +6,15 @@ This defines models for all class objects that we want to persist. It is used fo
of the class objects to database table columns
### repositories ###
This defines repository classes that contain CRUD methods for database operations.
This defines repository classes that contain CRUD methods for database operations. The constructor of all repositories requires
The database name to connect to and the application environment (PROD or TEST). Tests use a different database
from the production environment, which is why this is necessary. An example is shown below
```python
from persistence import CityRepo
# instantiate city repo for hub production database
city_repo = CityRepo(db_name='hub', app_env='PROD')
```
All database operations are conducted with the production database (*PROD*) named *hub* in the example above
### config_db ##
This Python file is a configuration class that contains variables that map to configuration parameters in a .env file.
@ -15,14 +23,18 @@ It also contains a method ``def conn_string()`` which returns the connection str
### Base ##
This class has a constructor that establishes a database connection and returns a reference for database-related CRUD operations.
### db_migration ###
This Python file is in the root of Hub and should be run to create all the required Postgres database tables
### Database Configuration Parameter ###
A .env file (or environment variables) with configuration parameters described below are needed to establish a database connection:
```
DB_USER=postgres-database-user
DB_PASSWORD=postgres-database-password
DB_HOST=database-host
DB_PORT=database-port
# production database credentials
PROD_DB_USER=postgres-database-user
PROD_DB_PASSWORD=postgres-database-password
PROD_DB_HOST=database-host
PROD_DB_PORT=database-port
# test database credentials
TEST_DB_USER=postgres-database-user
TEST_DB_PASSWORD=postgres-database-password
TEST_DB_HOST=database-host
TEST_DB_PORT=database-port
```

View File

@ -1,4 +1,3 @@
from .base import BaseRepo
from .repositories.city_repo import CityRepo
from .repositories.building_repo import BuildingRepo
from .repositories.city_repo import CityRepo

View File

@ -11,8 +11,20 @@ from sqlalchemy.orm import Session
class BaseRepo:
def __init__(self, db_name):
config = BaseConfiguration(db_name)
def __init__(self, db_name, app_env='TEST'):
config = BaseConfiguration(db_name, app_env)
engine = create_engine(config.conn_string())
self.config = config
self.session = Session(engine)
def __del__(self):
"""
Close database sessions
:return:
"""
self.session.close()

View File

@ -20,13 +20,17 @@ class BaseConfiguration(object):
"""
Base configuration class to hold common persistence configuration
"""
def __init__(self, db_name: str):
def __init__(self, db_name: str, app_env='TEST'):
"""
:param db_name: database name
:param app_env: application environment, test or production
"""
self._db_name = db_name
self._db_host = os.getenv('DB_HOST')
self._db_user = os.getenv('DB_USER')
self._db_pass = os.getenv('DB_PASSWORD')
self._db_port = os.getenv('DB_PORT')
self.hub_token = os.getenv('HUB_TOKEN')
self._db_host = os.getenv(f'{app_env}_DB_HOST')
self._db_user = os.getenv(f'{app_env}_DB_USER')
self._db_pass = os.getenv(f'{app_env}_DB_PASSWORD')
self._db_port = os.getenv(f'{app_env}_DB_PORT')
self.hub_token = os.getenv(f'{app_env}_HUB_TOKEN')
def conn_string(self):
"""

View File

@ -10,7 +10,7 @@ from persistence import BaseRepo
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select
from helpers.city_util import CityUtil
from persistence.models import City
from persistence.models import City as DBCity
import pickle
import requests
from urllib3.exceptions import HTTPError
@ -18,22 +18,32 @@ from typing import Union, Dict
class CityRepo(BaseRepo):
def __init__(self, db_name):
super().__init__(db_name)
_instance = None
def __init__(self, db_name, app_env):
super().__init__(db_name, app_env)
self._city_util = CityUtil()
def __new__(cls, db_name, app_env):
"""
Implemented for a singleton pattern
"""
if cls._instance is None:
cls._instance = super(CityRepo, cls).__new__(cls)
return cls._instance
def insert(self, city: City) -> Union[City, Dict]:
model_city = City()
model_city.name = city.name
model_city.climate_reference_city = city.climate_reference_city
model_city.srs_name = city.srs_name
model_city.longitude = city.longitude
model_city.latitude = city.latitude
model_city.country_code = city.country_code
model_city.time_zone = city.time_zone
model_city.lower_corner = city.lower_corner.tolist()
model_city.upper_corner = city.upper_corner.tolist()
model_city.city = pickle.dumps(city)
db_city = DBCity()
db_city.name = city.name
db_city.climate_reference_city = city.climate_reference_city
db_city.srs_name = city.srs_name
db_city.longitude = city.longitude
db_city.latitude = city.latitude
db_city.country_code = city.country_code
db_city.time_zone = city.time_zone
db_city.lower_corner = city.lower_corner.tolist()
db_city.upper_corner = city.upper_corner.tolist()
db_city.city = pickle.dumps(city)
try:
# Retrieve hub project latest release
@ -44,19 +54,20 @@ class CityRepo(BaseRepo):
# Do not persist the same city for the same version of Hub
if exiting_city is None:
model_city.hub_release = recent_commit
db_city.hub_release = recent_commit
cities = self.get_by_name(city.name)
# update version for the same city but different hub versions
if len(cities) == 0:
model_city.city_version = 0
db_city.city_version = 0
else:
model_city.city_version = cities[-1].city_version + 1
db_city.city_version = cities[-1].city_version + 1
# Persist city
self.session.add(model_city)
self.session.add(db_city)
self.session.flush()
self.session.commit()
return model_city
return db_city
else:
return {'message': f'Same version of {city.name} exist'}
except SQLAlchemyError as err:
@ -64,14 +75,14 @@ class CityRepo(BaseRepo):
except HTTPError as err:
print(f'Error retrieving Hub latest release: {err}')
def get_by_id(self, city_id: int) -> City:
def get_by_id(self, city_id: int) -> DBCity:
"""
Fetch a City based on the id
:param city_id: the city id
:return: a city
"""
try:
return self.session.execute(select(City).where(City.id == city_id)).first()[0]
return self.session.execute(select(DBCity).where(DBCity.id == city_id)).first()[0]
except SQLAlchemyError as err:
print(f'Error while fetching city: {err}')
@ -83,8 +94,8 @@ class CityRepo(BaseRepo):
:return: a city
"""
try:
return self.session.execute(select(City)
.where(City.hub_release == hub_commit, City.name == city_name)).first()
return self.session.execute(select(DBCity)
.where(DBCity.hub_release == hub_commit, DBCity.name == city_name)).first()
except SQLAlchemyError as err:
print(f'Error while fetching city: {err}')
@ -96,7 +107,7 @@ class CityRepo(BaseRepo):
:return:
"""
try:
self.session.query(City).filter(City.id == city_id) \
self.session.query(DBCity).filter(DBCity.id == city_id) \
.update({
'name': city.name, 'srs_name': city.srs_name, 'country_code': city.country_code, 'longitude': city.longitude,
'latitude': city.latitude, 'time_zone': city.time_zone, 'lower_corner': city.lower_corner.tolist(),
@ -107,14 +118,27 @@ class CityRepo(BaseRepo):
except SQLAlchemyError as err:
print(f'Error while updating city: {err}')
def get_by_name(self, city_name: str) -> [City]:
def get_by_name(self, city_name: str) -> [DBCity]:
"""
Fetch city based on the name
:param city_name: the name of the building
:return: [ModelCity] with the provided name
"""
try:
result_set = self.session.execute(select(City).where(City.name == city_name))
result_set = self.session.execute(select(DBCity).where(DBCity.name == city_name))
return [building[0] for building in result_set]
except SQLAlchemyError as err:
print(f'Error while fetching city by name: {err}')
def delete_city(self, city_id: int):
"""
Deletes a City with the id
:param city_id: the city id
:return: a city
"""
try:
self.session.query(DBCity).filter(DBCity.id == city_id).delete()
self.session.commit()
except SQLAlchemyError as err:
print(f'Error while fetching city: {err}')

View File

@ -0,0 +1,76 @@
"""
Test EnergySystemsFactory and various heatpump models
SPDX - License - Identifier: LGPL - 3.0 - or -later
Copyright © 2022 Concordia CERC group
Project Coder Peter Yefi peteryefi@gmail.com
"""
from unittest import TestCase
from imports.geometry_factory import GeometryFactory
from imports.db_factory import DBFactory
from exports.db_factory import DBFactory as ExportDBFactory
from persistence.db_config import BaseConfiguration
from sqlalchemy import create_engine
from persistence.models import City
from pickle import loads
class TestDBFactory(TestCase):
"""
TestDBFactory
"""
def setUp(self) -> None:
"""
Test setup
:return: None
"""
# Create test tables if they do not exit
config = BaseConfiguration(db_name='hub_test', app_env='TEST')
engine = create_engine(config.conn_string())
City.__table__.create(bind=engine, checkfirst=True)
city_file = "../unittests/tests_data/C40_Final.gml"
self.city = GeometryFactory('citygml', city_file).city
self._db_factory = DBFactory(city=self.city, db_name='hub_test', app_env='TEST')
self._export_db_factory = ExportDBFactory(city=self.city, db_name='hub_test', app_env='TEST')
def test_save_city(self):
self._saved_city = self._db_factory.persist_city()
self.assertEqual(self._saved_city.name, 'Montréal')
pickled_city = loads(self._saved_city.city)
self.assertEqual(len(pickled_city.buildings), 10)
self.assertEqual(pickled_city.buildings[0].floor_area, 1990.9913970530033)
self._db_factory.delete_city(self._saved_city.id)
def test_save_same_city_with_same_hub_version(self):
first_city = self._db_factory.persist_city()
second_city = self._db_factory.persist_city()
self.assertEqual(second_city['message'], f'Same version of {self.city.name} exist')
self.assertEqual(first_city.name, 'Montréal')
self.assertEqual(first_city.country_code, 'ca')
self._db_factory.delete_city(first_city.id)
def test_get_city_by_name(self):
city = self._db_factory.persist_city()
retrieved_city = self._export_db_factory.get_city_by_name(city.name)
self.assertEqual(retrieved_city[0].lower_corner[0], 610610.7547462888)
self._db_factory.delete_city(city.id)
def test_get_city_by_id(self):
city = self._db_factory.persist_city()
retrieved_city = self._export_db_factory.get_city(city.id)
self.assertEqual(retrieved_city.upper_corner[0], 610818.6731258357)
self._db_factory.delete_city(city.id)
def test_get_update_city(self):
city = self._db_factory.persist_city()
self.city.longitude = 1.43589
self.city.latitude = -9.38928339
self._db_factory.update_city(city.id, self.city)
updated_city = self._export_db_factory.get_city(city.id)
self.assertEqual(updated_city.longitude, 1.43589)
self.assertEqual(updated_city.latitude, -9.38928339)
self._db_factory.delete_city(city.id)