diff --git a/hub/persistence/db_control.py b/hub/persistence/db_control.py index 305ccaed..2d8dd399 100644 --- a/hub/persistence/db_control.py +++ b/hub/persistence/db_control.py @@ -86,25 +86,33 @@ class DBControl: result_names = [] results = {} for city in cities['cities']: - city_name = next(iter(city)) - result_set = self._city_repository.get_by_user_id_application_id_and_name(user_id, application_id, city_name) - if result_set is None: + scenario_name = next(iter(city)) + result_sets = self._city_repository.get_by_user_id_application_id_and_scenario( + user_id, + application_id, + scenario_name + ) + if result_sets is None: continue - city_id = result_set.id - results[city_name] = [] - for building_name in city[city_name]: - if self._city_object.get_by_name_or_alias_and_city(building_name, city_id) is None: - continue - city_object_id = self._city_object.get_by_name_or_alias_and_city(building_name, city_id).id - _ = self._simulation_results.get_simulation_results_by_city_id_city_object_id_and_names( - city_id, - city_object_id, - result_names) + for result_set in result_sets: + city_id = result_set[0].id + print('city ids', city_id) + results[scenario_name] = [] + for building_name in city[scenario_name]: + _building = self._city_object.get_by_name_or_alias_and_city(building_name, city_id) + if _building is None: + continue + city_object_id = _building.id + print('city object ids', city_object_id) + _ = self._simulation_results.get_simulation_results_by_city_id_city_object_id_and_names( + city_id, + city_object_id, + result_names) - for value in _: - values = json.loads(value.values) - values["building"] = building_name - results[city_name].append(values) + for value in _: + values = json.loads(value.values) + values["building"] = building_name + results[scenario_name].append(values) return results def persist_city(self, city: City, pickle_path, scenario, application_id: int, user_id: int): diff --git a/hub/persistence/repositories/city.py b/hub/persistence/repositories/city.py index 1e7d85ae..29048442 100644 --- a/hub/persistence/repositories/city.py +++ b/hub/persistence/repositories/city.py @@ -98,7 +98,7 @@ class City(Repository): logging.error('Error while fetching city %s', err) raise SQLAlchemyError from err - def get_by_user_id_application_id_and_scenario(self, user_id, application_id, scenario) -> Model: + def get_by_user_id_application_id_and_scenario(self, user_id, application_id, scenario) -> [Model]: """ Fetch city based on the user who created it :param user_id: the user id @@ -131,3 +131,4 @@ class City(Repository): except SQLAlchemyError as err: logging.error('Error while fetching city by name %s', err) raise SQLAlchemyError from err + diff --git a/hub/persistence/repositories/city_object.py b/hub/persistence/repositories/city_object.py index 04b9a82a..ae27d134 100644 --- a/hub/persistence/repositories/city_object.py +++ b/hub/persistence/repositories/city_object.py @@ -101,14 +101,23 @@ class CityObject(Repository): Fetch a city object based on name and city id :param name: city object name :param city_id: a city identifier - :return: [CityObject] with the provided name belonging to the city with id city_id + :return: [CityObject] with the provided name or alias belonging to the city with id city_id """ - _city_object = None try: - _city_object = self.session.execute(select(Model).where( - or_(Model.name == name, Model.aliases.contains(f'%{name}%')), Model.city_id == city_id - )).first() - return _city_object[0] + _city_objects = self.session.execute(select(Model).where( + or_(Model.name == name, Model.aliases.contains(f'{name}')), Model.city_id == city_id + )).all() + for city_object in _city_objects: + if city_object[0].name == name: + return city_object[0] + aliases = city_object[0].aliases.replace('{', '').replace('}', '').split(',') + for alias in aliases: + print(alias, name) + if alias == name: + # force the name as the alias + city_object[0].name = name + return city_object[0] + return None except SQLAlchemyError as err: logging.error('Error while fetching city object by name and city: %s', err) raise SQLAlchemyError from err diff --git a/tests/test_db_factory.py b/tests/test_db_factory.py index 5b826433..293f05d3 100644 --- a/tests/test_db_factory.py +++ b/tests/test_db_factory.py @@ -65,13 +65,13 @@ class Control: self._skip_test = True self._skip_reason = f'{operational_error}' return - + """ Application.__table__.create(bind=repository.engine, checkfirst=True) User.__table__.create(bind=repository.engine, checkfirst=True) City.__table__.create(bind=repository.engine, checkfirst=True) CityObject.__table__.create(bind=repository.engine, checkfirst=True) SimulationResults.__table__.create(bind=repository.engine, checkfirst=True) - + """ city_file = Path('tests_data/test.geojson').resolve() output_path = Path('tests_outputs/').resolve() self._city = GeometryFactory('geojson', @@ -103,11 +103,19 @@ class Control: app_env='TEST', dotenv_path=dotenv_path) - self._application_uuid = str(uuid.uuid4()) - self._application_id = self._database.persist_application('test', 'test application', self.application_uuid) - self._user_id = self._database.create_user('Admin', self._application_id, 'Admin@123', UserRoles.Admin) + self._application_uuid = '60b7fc1b-f389-4254-9ffd-22a4cf32c7a3' + self._application_id = 1 + self._user_id = 1 + + """ + self._application_id = self._database.persist_application( + 'City_layers', + 'City layers test user', + self.application_uuid + ) + self._user_id = self._database.create_user('city_layers', self._application_id, 'city_layers', UserRoles.Admin) + """ self._pickle_path = 'tests_data/pickle_path.bz2' - print('done') @property def database(self): @@ -182,6 +190,7 @@ TestDBFactory def test_get_update_city(self): city_id = control.database.persist_city(control.city, control.pickle_path, + control.city.name, control.application_id, control.user_id) control.city.name = "Ottawa" @@ -200,6 +209,7 @@ TestDBFactory def test_save_results(self): city_id = control.database.persist_city(control.city, control.pickle_path, + 'current status', control.application_id, control.user_id) city_objects_id = [] @@ -279,4 +289,4 @@ TestDBFactory def tearDownClass(cls): control.database.delete_application(control.application_uuid) control.database.delete_user(control.user_id) - """ \ No newline at end of file + """ \ No newline at end of file diff --git a/tests/tests_data/pickle_path.bz2 b/tests/tests_data/pickle_path.bz2 new file mode 100644 index 00000000..06e1d98f Binary files /dev/null and b/tests/tests_data/pickle_path.bz2 differ