diff --git a/functions/__Init__.py b/functions/__Init__.py index 14a466f..2aafaf5 100644 --- a/functions/__Init__.py +++ b/functions/__Init__.py @@ -1,13 +1,15 @@ from .population import process_buffer_population, process_travels from .network import process_nodes, process_links, process_links_attr -from .metro import process_metro, push_metro, metro_write -from .bus import process_bus, push_bus, bus_write -from .helpers import buffer_creator,push_to_db_coords,write_to_csv +from .metro import metro_processing +from .bus import bus_processing +from .helpers import file_validator,buffer_creator,push_to_db_coords,write_to_csv, push_to_db_linestring +from .printers import error_printer,success_printer,info_printer,notice_printer __all__ = [ - 'process_buffer_population', 'process_travels','push_to_db_coords', 'write_to_csv', - 'process_nodes', 'process_links', 'process_links_attr' - 'process_metro', 'push_metro', 'metro_write', - 'process_bus', 'push_bus', 'bus_write', - 'buffer_creator' + 'process_buffer_population', 'process_travels', + 'process_nodes', 'process_links', 'process_links_attr', + 'metro_processing', + 'bus_processing', + 'file_validator','buffer_creator','push_to_db_coords', 'write_to_csv', 'push_to_db_linestring' + 'error_printer','success_printer','info_printer', 'notice_printer', ] diff --git a/functions/bus.py b/functions/bus.py index e93620f..7955e85 100644 --- a/functions/bus.py +++ b/functions/bus.py @@ -1,9 +1,32 @@ +from .printers import error_printer, success_printer +import geopandas, typer -def process_bus(data, cleandata): - print(data, cleandata) +def bus_processing(city, files): + df_stm_arrets_sig = None + df_stm_lignes_sig = None -def push_bus(data,mode): - print(data,mode) + if city == "mtl": + required_files = ["stm_arrets_sig.shp", "stm_lignes_sig.shp"] + common_files = [file for file in files if file.name in required_files] + if len(common_files) == 0: + error_printer("Incorrect file input") + raise typer.Exit() + + for file in common_files: + if file.name == required_files[0]: + try: + df_arrets = geopandas.read_file(file) + df_arrets_filtered = df_arrets[~df_arrets['stop_url'].str.contains('metro', case=False, na=False)] + df_arrets_filtered = df_arrets_filtered.rename(columns={'geometry': 'coordinates'}) + df_stm_arrets_sig = df_arrets_filtered + except Exception as e: + error_printer(f"Failed to process stm_arrets_sig.shp: {e}") + elif file.name == required_files[1]: + try: + df_lignes = geopandas.read_file(file) + df_lignes_filtered = df_lignes[~df_lignes['route_name'].str.contains('Ligne', na=False)] + df_stm_lignes_sig = df_lignes_filtered + except Exception as e: + error_printer(f"Failed to process stm_lignes_sig.shp: {e}") -def bus_write(data): - print(data) \ No newline at end of file + return df_stm_arrets_sig, df_stm_lignes_sig \ No newline at end of file diff --git a/functions/helpers.py b/functions/helpers.py index ad4555a..8d50cd0 100644 --- a/functions/helpers.py +++ b/functions/helpers.py @@ -1,6 +1,25 @@ -import geopandas, os, datetime +import geopandas, os, datetime, typer from sqlalchemy import create_engine from geoalchemy2 import Geometry, WKTElement +from .printers import error_printer, success_printer + +def file_validator(file): + if not file.exists(): + error_printer("File not found") + raise typer.Exit() + try: + f = open(file, 'r', encoding='utf-8') + success_printer("File Opened") + except: + error_printer("Unable to read file") + raise typer.Exit() + count = sum(1 for _ in f) + if count == 0: + error_printer("File empty") + raise typer.Exit() + else: + success_printer(f"{count + 1} lines found") + f.close() def buffer_creator(file,divider,start_line, chunk_size): buffer = [] @@ -18,16 +37,40 @@ def buffer_creator(file,divider,start_line, chunk_size): buffer.append(line.strip()) return current_line,(' ').join(buffer) -def push_to_db_coords(name,data,mode): - GDF = geopandas.GeoDataFrame(data, crs='EPSG:4326') - GDF['geom'] = GDF['coordinates'].apply(lambda x: WKTElement(x.wkt, srid=os.getenv("SRID"))) - engine = create_engine(f'postgresql://{os.getenv("USER")}:{os.getenv("PASS")}@{os.getenv("HOST_NAME")}/{os.getenv("DATA_BASE")}', echo=False) +def push_to_db_coords(name, data, mode): + if not isinstance(data, geopandas.GeoDataFrame): + GDF = geopandas.GeoDataFrame(data, crs='EPSG:4326') + else: + GDF = data + GDF['geom'] = GDF['coordinates'].apply(lambda x: WKTElement(x.wkt, srid=int(os.getenv("SRID")))) + engine = create_engine( + f'postgresql://{os.getenv("USER")}:{os.getenv("PASS")}@{os.getenv("HOST_NAME")}/{os.getenv("DATA_BASE")}', + echo=False + ) GDF.to_sql( name=name, con=engine, if_exists=mode, - chunksize=os.getenv("CHUNK_SIZE"), - dtype={'geom': Geometry('Point', srid=os.getenv("SRID"))}, + chunksize=int(os.getenv("CHUNK_SIZE")), + dtype={'geom': Geometry('Point', srid=int(os.getenv("SRID")))}, + index=False + ) + +def push_to_db_linestring(name, data, mode): + + data = data.to_crs('EPSG:4326') + data['geom'] = data['geometry'].apply(lambda x: WKTElement(x.wkt, srid=int(os.getenv("SRID")))) + data = data.drop(columns=['geometry']) + engine = create_engine( + f'postgresql://{os.getenv("USER")}:{os.getenv("PASS")}@{os.getenv("HOST_NAME")}/{os.getenv("DATA_BASE")}', + echo=False + ) + data.to_sql( + name=name, + con=engine, + if_exists=mode, + chunksize=int(os.getenv("CHUNK_SIZE")), + dtype={'geom': Geometry('LINESTRING', srid=int(os.getenv("SRID")))}, index=False ) diff --git a/functions/metro.py b/functions/metro.py index c5f4f98..2aa46ff 100644 --- a/functions/metro.py +++ b/functions/metro.py @@ -1,8 +1,34 @@ -def process_metro(data, cleandata): - print(data, cleandata) +from .printers import error_printer, success_printer +import geopandas, typer -def push_metro(data,mode): - print(data,mode) +def metro_processing(city, files): -def metro_write(data): - print(data) \ No newline at end of file + df_stm_arrets_sig = None + df_stm_lignes_sig = None + + if city == "mtl": + required_files = ["stm_arrets_sig.shp", "stm_lignes_sig.shp"] + common_files = [file for file in files if file.stem in required_files] + + if len(common_files) == 0: + error_printer("Incorrect file input") + raise typer.Exit() + + for file in common_files: + if file == required_files[0]: + try: + df_arrets = geopandas.read_file(file) + df_arrets_filtered = df_arrets[df_arrets['stop_url'].str.contains('metro', case=False, na=False)] + df_arrets_filtered = df_arrets_filtered.rename(columns={'geometry': 'coordinates'}) + df_stm_arrets_sig = df_arrets_filtered + except Exception as e: + error_printer(f"Failed to process stm_arrets_sig.shp: {e}") + elif file == required_files[1]: + try: + df_lignes = geopandas.read_file(file) + df_lignes_filtered = df_lignes[df_lignes['route_name'].str.contains('Ligne', na=False)] + df_stm_lignes_sig = df_lignes_filtered + except Exception as e: + error_printer(f"Failed to process stm_lignes_sig.shp: {e}") + + return df_stm_arrets_sig, df_stm_lignes_sig \ No newline at end of file diff --git a/functions/printers.py b/functions/printers.py new file mode 100644 index 0000000..35ab5d2 --- /dev/null +++ b/functions/printers.py @@ -0,0 +1,10 @@ +from rich import print + +def error_printer(text): + print(f'[bold red]ERROR:[/bold red] [bold]{text}[/bold]') +def success_printer(text): + print(f'[bold green]SUCCESS:[/bold green] [bold]{text}[/bold]') +def info_printer(text): + print(f'[bold blue]INFO:[/bold blue] [bold]{text}[/bold]') +def notice_printer(text): + print(f'[bold yellow]NOTICE:[/bold yellow] [bold]{text}[/bold]') \ No newline at end of file diff --git a/main.py b/main.py index 1f240f7..786263b 100644 --- a/main.py +++ b/main.py @@ -11,25 +11,16 @@ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskPr import time from classes import City, DBMode, RTMode -from functions import buffer_creator, process_buffer_population,process_travels, push_to_db_coords, write_to_csv +from functions import file_validator,buffer_creator, process_buffer_population,process_travels, push_to_db_coords, write_to_csv, push_to_db_linestring from functions import process_nodes,process_links,process_links_attr +from functions import error_printer,success_printer,info_printer,notice_printer +from functions import metro_processing, bus_processing from styles import print_help -import xml.etree.ElementTree as ET - called= "population" app = typer.Typer(rich_markup_mode="rich") load_dotenv() -def error_printer(text): - print(f'[bold red]ERROR:[/bold red] [bold]{text}[/bold]') -def success_printer(text): - print(f'[bold green]SUCCESS:[/bold green] [bold]{text}[/bold]') -def info_printer(text): - print(f'[bold blue]INFO:[/bold blue] [bold]{text}[/bold]') -def notice_printer(text): - print(f'[bold yellow]NOTICE:[/bold yellow] [bold]{text}[/bold]') - @app.command(print_help()) def population( file: Annotated[Path, typer.Argument(help="Provide the relative path to the [yellow bold underline]XML file[/yellow bold underline].", show_default=False)], @@ -49,24 +40,7 @@ def population( elif "all" in common_tables: common_tables = all_tables info_printer(f"Tables to inlude: {common_tables}") - if not file.exists(): - error_printer("File not found") - raise typer.Exit() - try: - f = open(file, 'r', encoding='utf-8') - success_printer("File Opened") - except: - error_printer("Unable to read file") - raise typer.Exit() - - count = sum(1 for _ in f) - if count == 0: - error_printer("File empty") - raise typer.Exit() - else: - success_printer(f"{count + 1} lines read") - f.close() - max_chunk = 0 + file_validator(file) with open(file,'r',encoding='utf-8') as f: for line in f: if line.strip() == os.getenv("DIVIDER"): @@ -157,23 +131,7 @@ def network( elif "all" in common_tables: common_tables = all_tables info_printer(f"Tables to inlude: {common_tables}") - if not file.exists(): - error_printer("File not found") - raise typer.Exit() - try: - f = open(file, 'r', encoding='utf-8') - success_printer("File Opened") - except: - error_printer("Unable to read file") - raise typer.Exit() - - count = sum(1 for _ in f) - if count == 0: - error_printer("File empty") - raise typer.Exit() - else: - success_printer(f"{count + 1} lines found") - f.close() + file_validator(file) BUFFER = [] DEVIDER_COUNT = 0 with Progress( @@ -261,18 +219,71 @@ def network( @app.command() def metro( city: Annotated[City, typer.Argument(..., help="Choose a city", show_default=False)], - mode: Annotated[RTMode, typer.Argument(..., help="Choose a city", show_default=False)], - address: Annotated[str, typer.Argument(..., help="enter a relative path or URL", show_default=False)], - ): - print(f"Hello {city}") + files: list[Path] = typer.Option(None, "--files", "-f", help="Provide the relative path to [yellow bold underline]shape files[/yellow bold underline].", show_default=False), + cleandata: bool = typer.Option(False, "--cleandata", "-cd", help="Drop the rows that have missing values."), + push: bool = typer.Option(False, "--push", "-p", help="Save the output directly to the database When mentioned. Otherwise, saves as a [green bold]CSV file[/green bold] in the input directory"), + pushmode: Optional[DBMode] = typer.Option(None, help="Specify either [underline]'append'[/underline] or [underline]'drop'[/underline] when pushing data", show_default=False), +): + for file in files: + if not file.exists(): + error_printer(f"Shapefile {file} does not exist.") + raise typer.Exit() + if file.suffix != '.shp': + error_printer(f"File {file} is not a .shp file.") + raise typer.Exit() + success_printer("Shapefiles validated successfully.") + metro_stations_df, metro_lines_df = metro_processing(city, files) + if not metro_stations_df or not metro_lines_df: + error_printer("dataframes were processed successfully") + raise typer.Exit() + if cleandata: + if metro_stations_df: metro_stations_df = metro_stations_df.dropna() + if metro_lines_df: metro_lines_df = metro_lines_df.dropna() + if push: + if metro_stations_df: push_to_db_coords("metro-stations",metro_stations_df,pushmode) + if metro_lines_df: push_to_db_linestring("metro-lines",metro_lines_df, pushmode) + else: + if metro_stations_df: write_to_csv("metro-stations",metro_stations_df,file) + if metro_lines_df: write_to_csv("metro-lines",metro_lines_df,file) + success_printer("Processing complete.") @app.command() def bus( city: Annotated[City, typer.Argument(..., help="Choose a city", show_default=False)], - mode: Annotated[RTMode, typer.Argument(..., help="Choose a city", show_default=False)], - address: Annotated[str, typer.Argument(..., help="enter a relative path or URL", show_default=False)], + files: list[Path] = typer.Option(None, "--files", "-f", help="Provide the relative path to [yellow bold underline]shape files[/yellow bold underline].", show_default=False), + cleandata: bool = typer.Option(False, "--cleandata", "-cd", help="Drop the rows that have missing values."), + push: bool = typer.Option(False, "--push", "-p", help="Save the output directly to the database when mentioned. Otherwise, saves as a [green bold]CSV file[/green bold] in the input directory"), + pushmode: Optional[DBMode] = typer.Option(None, help="Specify either [underline]'append'[/underline] or [underline]'drop'[/underline] when pushing data", show_default=False), ): - print(f"Hello {city}") + for file in files: + if not file.exists(): + error_printer(f"Shapefile {file} does not exist.") + raise typer.Exit() + if file.suffix != '.shp': + error_printer(f"File {file} is not a .shp file.") + raise typer.Exit() + success_printer("Shapefiles validated successfully.") + bus_stations_df, bus_lines_df = bus_processing(city, files) + if not bus_stations_df and not bus_lines_df: + error_printer("No dataframes were processed successfully.") + raise typer.Exit() + if cleandata: + if bus_stations_df is not None: + bus_stations_df = bus_stations_df.dropna() + if bus_lines_df is not None: + bus_lines_df = bus_lines_df.dropna() + if push: + if bus_stations_df is not None: + push_to_db_coords("bus-stations", bus_stations_df, pushmode) + if bus_lines_df is not None: + push_to_db_linestring("bus-lines", bus_lines_df, pushmode) + else: + if bus_stations_df is not None: + write_to_csv("bus-stations", bus_stations_df, file) + if bus_lines_df is not None: + write_to_csv("bus-lines", bus_lines_df, file) + success_printer("Processing complete.") + if __name__ == "__main__": app() \ No newline at end of file