diff --git a/cerc_persistence/repositories/city_object.py b/cerc_persistence/repositories/city_object.py index 9e33b8c..6fa720c 100644 --- a/cerc_persistence/repositories/city_object.py +++ b/cerc_persistence/repositories/city_object.py @@ -136,7 +136,7 @@ class CityObject(Repository): logging.error('Error while fetching city object by name and city, empty result %s', err) raise IndexError from err - def get_by_name_or_alias_for_user_app(self, user_id, application_id, names) -> Union[Model, None]: + def get_by_name_or_alias_for_user_app(self, user_id, application_id, names) -> [Model]: """ Fetch city objects belonging to the user and application where the name or alias is in the names list :param user_id: User ID @@ -144,17 +144,24 @@ class CityObject(Repository): :param names: a list of building aliases or names :return [CityObject] or None """ - with Session(self.engine) as session: - cities = session.execute(select(CityModel).where( - CityModel.user_id == user_id, CityModel.application_id == application_id - )).all() - ids = [c[0].id for c in cities] - buildings = session.execute(select(Model).where( - Model.city_id.in_(ids), Model.name.in_(names) - )) - results = [r[0] for r in buildings] - print(ids, buildings) - return None + try: + results = [] + with Session(self.engine) as session: + cities = session.execute(select(CityModel).where( + CityModel.user_id == user_id, CityModel.application_id == application_id + )).all() + ids = [c[0].id for c in cities] + buildings = session.execute(select(Model).where( + Model.city_id.in_(ids), Model.name.in_(names) + )) + results = [r[0] for r in buildings] + return results + except SQLAlchemyError as err: + logging.error('Error while fetching city object by name and city: %s', err) + raise SQLAlchemyError from err + except IndexError as err: + logging.error('Error while fetching city object by name and city, empty result %s', err) + raise IndexError from err def get_by_name_or_alias_and_city(self, name, city_id) -> Union[Model, None]: """ diff --git a/tests/test_db_factory.py b/tests/test_db_factory.py index 24e140a..9320677 100644 --- a/tests/test_db_factory.py +++ b/tests/test_db_factory.py @@ -48,6 +48,7 @@ class Control: # Create test database. dotenv_path = Path("{}/.local/etc/hub/.env".format(os.path.expanduser('~'))).resolve() + print(dotenv_path) if not dotenv_path.exists(): self._skip_test = True self._skip_reason = f'.env file missing at {dotenv_path}' @@ -397,9 +398,10 @@ TestDBFactory city_object = control.database.building_info_in_cities(test_building.name, [mtl_city_id, ott_city_id]) self.assertEqual(city_object.name, test_building.name, "City name does not match") - # TODO: Waiting for the code to be finished to complete this test. city_objects = control.database.buildings_info(control.user_id, control.application_id, [control.city.buildings[0].name, control.city.buildings[1].name]) - # for city_obj in city_objects: + self.assertEqual(len(city_objects), 4, "Found {} city objects but expected 4".format(len(city_objects))) + for city_obj in city_objects: + self.assertTrue(city_obj.name == control.city.buildings[0].name or city_obj.name == control.city.buildings[1].name, "City object name does not match any expected value. Obtained {} instead.".format(city_obj.name)) control.database.delete_city(mtl_city_id) control.database.delete_city(ott_city_id) @@ -439,7 +441,7 @@ TestDBFactory city_objects = control.database.buildings_info(control.user_id, control.application_id, [control.city.buildings[0].name, control.city.buildings[1].name]) - self.assertEqual(len(city_objects), 0, "Found a city object with given values"); + self.assertEqual(len(city_objects), 0, "Found a city object with given values") @classmethod @unittest.skipIf(control.skip_test, control.skip_reason)