"""Load a GTFS static feed into SQLite and query departures / realtime updates."""

import csv
import sqlite3
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, TypeVar

import requests
from google.transit import gtfs_realtime_pb2

# Source of the static GTFS archive used by ``download_data``.
GTFS_URL = "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile"

_T = TypeVar("_T")


def _read_csv(csv_path: Path, parse: Callable[[dict], _T]) -> list[_T]:
    """Parse every row of a CSV file with *parse*, skipping malformed rows.

    Rows for which *parse* raises ``ValueError`` or ``KeyError`` are reported
    on stdout and skipped, so a single bad line does not abort the import.
    The file is opened with ``utf-8-sig`` so a leading BOM does not end up in
    the first column name.
    """
    items: list[_T] = []
    with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
        for row in csv.DictReader(csvfile):
            try:
                items.append(parse(row))
            except (ValueError, KeyError) as e:
                print(f"Error processing row: {row}. Error: {e}")
    return items


@dataclass
class Stop:
    """A row of the GTFS ``stops`` table."""

    stop_id: int
    stop_code: str
    stop_name: str
    stop_lat: float
    stop_lon: float
    zone_id: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Stop":
        """Build a Stop from a SQLite row of the ``stops`` table."""
        return cls(
            stop_id=row["stop_id"],
            stop_code=row["stop_code"],
            stop_name=row["stop_name"],
            stop_lat=row["stop_lat"],
            stop_lon=row["stop_lon"],
            zone_id=row["zone_id"],
        )

    @classmethod
    def from_csv_row(cls, row) -> "Stop":
        """Build a Stop from a ``csv.DictReader`` row of ``stops.txt``.

        Raises ``ValueError`` for non-numeric id/coordinates and ``KeyError``
        for a missing column.
        """
        # DictReader already removes CSV quoting; strip('"') additionally
        # drops any literal quote characters left at the ends of the value.
        return cls(
            stop_id=int(row["stop_id"]),
            stop_code=row["stop_code"].strip('"'),
            stop_name=row["stop_name"].strip('"'),
            stop_lat=float(row["stop_lat"]),
            stop_lon=float(row["stop_lon"]),
            zone_id=row["zone_id"],
        )


@dataclass
class Trip:
    """A row of the GTFS ``trips`` table (plus the feed-specific ``brigade``)."""

    route_id: str
    service_id: str
    trip_id: str
    trip_headsign: str
    direction_id: int
    shape_id: int
    wheelchair_accessible: int
    brigade: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Trip":
        """Build a Trip from a SQLite row containing the ``trips`` columns."""
        return cls(
            route_id=row["route_id"],
            service_id=row["service_id"],
            trip_id=row["trip_id"],
            trip_headsign=row["trip_headsign"],
            direction_id=row["direction_id"],
            shape_id=row["shape_id"],
            wheelchair_accessible=row["wheelchair_accessible"],
            brigade=row["brigade"],
        )

    @classmethod
    def from_csv_row(cls, row) -> "Trip":
        """Build a Trip from a ``csv.DictReader`` row of ``trips.txt``.

        Raises ``ValueError`` for non-integer numeric columns and ``KeyError``
        for a missing column.
        """
        return cls(
            route_id=row["route_id"].strip('"'),
            service_id=row["service_id"].strip('"'),
            trip_id=row["trip_id"].strip('"'),
            trip_headsign=row["trip_headsign"].strip('"'),
            direction_id=int(row["direction_id"]),
            shape_id=int(row["shape_id"]),
            wheelchair_accessible=int(row["wheelchair_accessible"]),
            brigade=int(row["brigade"]),
        )


@dataclass
class StopTime:
    """A row of the GTFS ``stop_times`` table."""

    trip_id: str
    arrival_time: str  # Using str for now since SQLite doesn't have a TIME type
    departure_time: str
    stop_id: int
    stop_sequence: int
    stop_headsign: str
    pickup_type: int
    drop_off_type: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "StopTime":
        """Build a StopTime from a SQLite row containing the ``stop_times`` columns."""
        return cls(
            trip_id=row["trip_id"],
            arrival_time=row["arrival_time"],
            departure_time=row["departure_time"],
            stop_id=row["stop_id"],
            stop_sequence=row["stop_sequence"],
            stop_headsign=row["stop_headsign"],
            pickup_type=row["pickup_type"],
            drop_off_type=row["drop_off_type"],
        )

    @classmethod
    def from_csv_row(cls, row) -> "StopTime":
        """Build a StopTime from a ``csv.DictReader`` row of ``stop_times.txt``.

        Raises ``ValueError`` for non-integer numeric columns and ``KeyError``
        for a missing column.
        """
        return cls(
            trip_id=row["trip_id"].strip('"'),
            arrival_time=row["arrival_time"],
            departure_time=row["departure_time"],
            stop_id=int(row["stop_id"]),
            stop_sequence=int(row["stop_sequence"]),
            stop_headsign=row["stop_headsign"].strip('"'),
            pickup_type=int(row["pickup_type"]),
            drop_off_type=int(row["drop_off_type"]),
        )


class Database:
    """Typed wrapper around a SQLite connection, used as a context manager.

    The connection is opened in ``__enter__`` with ``sqlite3.Row`` as the row
    factory (so columns can be read by name) and closed in ``__exit__``.
    """

    conn: sqlite3.Connection

    def __init__(self, db_path: str):
        self.db_path = db_path

    def __enter__(self) -> "Database":
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # close() does not commit; the repositories commit after each write.
        self.conn.close()

    def execute(self, query: str, params=()) -> sqlite3.Cursor:
        """Execute one statement; *params* is a sequence or mapping of bindings."""
        # Immutable default: a shared mutable ``{}`` default is a latent hazard.
        return self.conn.execute(query, params)

    def executemany(self, query: str, params_list=()) -> sqlite3.Cursor:
        """Execute one statement once per parameter set in *params_list*."""
        return self.conn.executemany(query, params_list)

    def commit(self) -> None:
        self.conn.commit()


class StopRepository:
    """Typed access to the ``stops`` table (plus a stop_times/trips join)."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``stops`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stops (
                stop_id INTEGER PRIMARY KEY,
                stop_code TEXT NOT NULL,
                stop_name TEXT NOT NULL,
                stop_lat REAL NOT NULL,
                stop_lon REAL NOT NULL,
                zone_id TEXT NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, stop: Stop) -> None:
        """Insert a single stop and commit."""
        self.db.execute(
            """
            INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                stop.stop_id,
                stop.stop_code,
                stop.stop_name,
                stop.stop_lat,
                stop.stop_lon,
                stop.zone_id,
            ),
        )
        self.db.commit()

    def bulk_insert(self, stops: list[Stop]) -> None:
        """Insert multiple stops efficiently."""
        self.db.executemany(
            """
            INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            [
                (s.stop_id, s.stop_code, s.stop_name, s.stop_lat, s.stop_lon, s.zone_id)
                for s in stops
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stops from a CSV file, skipping (and reporting) malformed rows."""
        stops = _read_csv(csv_path, Stop.from_csv_row)
        if stops:
            self.bulk_insert(stops)

    def get_by_id(self, stop_id: int) -> Optional[Stop]:
        """Return the stop with *stop_id*, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM stops WHERE stop_id = ?", (stop_id,))
        row = cursor.fetchone()
        return Stop.from_row(row) if row else None

    def get_all(self) -> Iterator[Stop]:
        """Lazily yield every stop in the table."""
        cursor = self.db.execute("SELECT * FROM stops")
        for row in cursor:
            yield Stop.from_row(row)

    def find_by_name(self, name: str) -> List[Stop]:
        """Return stops whose name contains *name* (SQL ``LIKE '%name%'``)."""
        cursor = self.db.execute(
            "SELECT * FROM stops WHERE stop_name LIKE ?", (f"%{name}%",)
        )
        return [Stop.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stops table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stops LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]  # SQLite returns 0 if empty, 1 if not empty

    def get_upcoming_departures(
        self, stop_id: int, limit: int = 10
    ) -> List[Tuple[StopTime, Trip]]:
        """Get departures from a stop together with their trip information.

        Results are ordered by ``departure_time`` and truncated to *limit*.
        NOTE(review): nothing filters on the current time or the trip's service
        day, so these are the earliest departures stored for the stop, not
        necessarily the next ones from "now".
        """
        cursor = self.db.execute(
            """
            SELECT st.*, t.*
            FROM stop_times st
            JOIN trips t ON st.trip_id = t.trip_id
            WHERE st.stop_id = ?
            ORDER BY st.departure_time
            LIMIT ?
            """,
            (stop_id, limit),
        )
        # Both tables contribute a ``trip_id`` column; the join condition makes
        # them equal, so each from_row can read it by name.
        return [(StopTime.from_row(row), Trip.from_row(row)) for row in cursor]


class StopTimeRepository:
    """Typed access to the ``stop_times`` table."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``stop_times`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stop_times (
                trip_id TEXT NOT NULL,
                arrival_time TEXT NOT NULL,
                departure_time TEXT NOT NULL,
                stop_id INTEGER NOT NULL,
                stop_sequence INTEGER NOT NULL,
                stop_headsign TEXT NOT NULL,
                pickup_type INTEGER NOT NULL,
                drop_off_type INTEGER NOT NULL,
                FOREIGN KEY (stop_id) REFERENCES stops (stop_id)
            )
        """)
        self.db.commit()

    def insert(self, stop_time: StopTime) -> None:
        """Insert a single stop time and commit."""
        self.db.execute(
            """
            INSERT INTO stop_times (
                trip_id, arrival_time, departure_time, stop_id,
                stop_sequence, stop_headsign, pickup_type, drop_off_type
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                stop_time.trip_id,
                stop_time.arrival_time,
                stop_time.departure_time,
                stop_time.stop_id,
                stop_time.stop_sequence,
                stop_time.stop_headsign,
                stop_time.pickup_type,
                stop_time.drop_off_type,
            ),
        )
        self.db.commit()

    def bulk_insert(self, stop_times: list[StopTime]) -> None:
        """Insert multiple stop times in one ``executemany`` call and commit."""
        self.db.executemany(
            """
            INSERT INTO stop_times (
                trip_id, arrival_time, departure_time, stop_id,
                stop_sequence, stop_headsign, pickup_type, drop_off_type
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                (
                    st.trip_id,
                    st.arrival_time,
                    st.departure_time,
                    st.stop_id,
                    st.stop_sequence,
                    st.stop_headsign,
                    st.pickup_type,
                    st.drop_off_type,
                )
                for st in stop_times
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stop times from a CSV file, skipping (and reporting) malformed rows."""
        stop_times = _read_csv(csv_path, StopTime.from_csv_row)
        if stop_times:
            self.bulk_insert(stop_times)

    def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
        """Return the stop times of *trip_id* in ``stop_sequence`` order."""
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
            (trip_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
        """Return the stop times at *stop_id* ordered by ``arrival_time``."""
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
            (stop_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stop_times table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stop_times LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]


class TripRepository:
    """Typed access to the ``trips`` table."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``trips`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS trips (
                route_id TEXT NOT NULL,
                service_id TEXT NOT NULL,
                trip_id TEXT PRIMARY KEY,
                trip_headsign TEXT NOT NULL,
                direction_id INTEGER NOT NULL,
                shape_id INTEGER NOT NULL,
                wheelchair_accessible INTEGER NOT NULL,
                brigade INTEGER NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, trip: Trip) -> None:
        """Insert a single trip and commit."""
        self.db.execute(
            """
            INSERT INTO trips (
                route_id, service_id, trip_id, trip_headsign,
                direction_id, shape_id, wheelchair_accessible, brigade
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                trip.route_id,
                trip.service_id,
                trip.trip_id,
                trip.trip_headsign,
                trip.direction_id,
                trip.shape_id,
                trip.wheelchair_accessible,
                trip.brigade,
            ),
        )
        self.db.commit()

    def bulk_insert(self, trips: list[Trip]) -> None:
        """Insert multiple trips in one ``executemany`` call and commit."""
        self.db.executemany(
            """
            INSERT INTO trips (
                route_id, service_id, trip_id, trip_headsign,
                direction_id, shape_id, wheelchair_accessible, brigade
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                (
                    t.route_id,
                    t.service_id,
                    t.trip_id,
                    t.trip_headsign,
                    t.direction_id,
                    t.shape_id,
                    t.wheelchair_accessible,
                    t.brigade,
                )
                for t in trips
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load trips from a CSV file, skipping (and reporting) malformed rows."""
        trips = _read_csv(csv_path, Trip.from_csv_row)
        if trips:
            self.bulk_insert(trips)

    def get_by_id(self, trip_id: str) -> Optional[Trip]:
        """Return the trip with *trip_id*, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM trips WHERE trip_id = ?", (trip_id,))
        row = cursor.fetchone()
        return Trip.from_row(row) if row else None

    def get_by_route_id(self, route_id: str) -> List[Trip]:
        """Return all trips on *route_id*."""
        cursor = self.db.execute("SELECT * FROM trips WHERE route_id = ?", (route_id,))
        return [Trip.from_row(row) for row in cursor]

    def get_by_headsign(self, headsign: str) -> List[Trip]:
        """Return trips whose headsign contains *headsign* (SQL ``LIKE``)."""
        cursor = self.db.execute(
            "SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",)
        )
        return [Trip.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the trips table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM trips LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]


def download_data(url: str = GTFS_URL, timeout: float = 60.0) -> None:
    """Download the GTFS archive to ``data.zip`` and extract it into ``tmp/``.

    Args:
        url: Address of the GTFS zip; defaults to the ZTM Poznań feed.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never saved and later mistaken for a zip archive.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Use a context manager so the file handle is always closed/flushed
    # before zipfile reopens it.
    with open("data.zip", "wb") as archive:
        archive.write(response.content)
    with zipfile.ZipFile("data.zip", "r") as zip_ref:
        zip_ref.extractall("tmp")


def get_realtime(trip: Trip) -> None:
    """Print realtime updates for *trip* read from ``trip_updates.pb``.

    Parses the local file as a GTFS-Realtime ``FeedMessage`` and prints the id
    and ``trip_update`` of every entity whose trip id matches ``trip.trip_id``.
    """
    with open("trip_updates.pb", "rb") as f:
        payload = f.read()
    feed = gtfs_realtime_pb2.FeedMessage()  # type: ignore
    feed.ParseFromString(payload)
    for entity in feed.entity:
        # NOTE(review): the realtime trip id is compared with its last two
        # characters removed — presumably a suffix the static feed lacks; confirm.
        if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id):
            print(entity.id)
            print(entity.trip_update)


def main() -> None:
    """Populate ``transit.db`` from ``tmp/`` if needed and show one departure.

    Tables are created if missing and filled from the extracted GTFS text files
    only when empty. Then the first departure for the first stop whose name
    contains "Polanka" is printed, along with any matching realtime updates.
    """
    with Database("transit.db") as db:
        stop_repo = StopRepository(db)
        stop_time_repo = StopTimeRepository(db)
        trip_repo = TripRepository(db)

        # Create tables
        stop_repo.create_table()
        stop_time_repo.create_table()
        trip_repo.create_table()

        # Load data if tables are empty
        if stop_repo.is_empty():
            stop_repo.load_from_csv(Path("tmp/stops.txt"))
        if stop_time_repo.is_empty():
            stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
        if trip_repo.is_empty():
            trip_repo.load_from_csv(Path("tmp/trips.txt"))

        matches = stop_repo.find_by_name("Polanka")
        if not matches:
            # Previously an unexplained IndexError; report it clearly instead.
            print('No stop matching "Polanka" found.')
            return
        stop = matches[0]

        for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 1):
            print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}')
            get_realtime(trip)


if __name__ == "__main__":
    main()