diff --git a/README.md b/README.md index ab49c23..38b5a76 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ +### Getting started + +To run a simple demo script you can use `uv` with: + + uv run src/realtime.py + ### API documentation ##### ZTM API docs diff --git a/scripts/realtime.py b/src/realtime.py similarity index 91% rename from scripts/realtime.py rename to src/realtime.py index 68e19c8..7590ca9 100755 --- a/scripts/realtime.py +++ b/src/realtime.py @@ -1,5 +1,6 @@ import zipfile import requests +import os from google.transit import gtfs_realtime_pb2 from dataclasses import dataclass @@ -40,6 +41,7 @@ class Stop: zone_id=row["zone_id"], ) + @dataclass class Trip: route_id: str @@ -77,6 +79,7 @@ class Trip: brigade=int(row["brigade"]), ) + @dataclass class StopTime: trip_id: str @@ -114,6 +117,7 @@ class StopTime: drop_off_type=int(row["drop_off_type"]), ) + # 2. Create a typed database wrapper class Database: conn: sqlite3.Connection @@ -228,32 +232,37 @@ class StopRepository: result = cursor.fetchone() return not result[0] # SQLite returns 0 if empty, 1 if not empty - def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]: + def get_upcoming_departures( + self, stop_id: int, limit: int = 10 + ) -> List[Tuple[StopTime, Trip]]: """Get upcoming departures from a stop with trip information.""" - cursor = self.db.execute(""" + cursor = self.db.execute( + """ SELECT st.*, t.* FROM stop_times st JOIN trips t ON st.trip_id = t.trip_id WHERE st.stop_id = ? ORDER BY st.departure_time LIMIT ? - """, (stop_id, limit)) - + """, + (stop_id, limit), + ) + departures = [] for row in cursor: stop_time = StopTime.from_row(row) trip = Trip( - route_id=row['route_id'], - service_id=row['service_id'], - trip_id=row['trip_id'], - trip_headsign=row['trip_headsign'], - direction_id=row['direction_id'], - shape_id=row['shape_id'], - wheelchair_accessible=row['wheelchair_accessible'], - brigade=row['brigade'] + route_id=row["route_id"], + service_id=row["service_id"], + trip_id=row["trip_id"], + trip_headsign=row["trip_headsign"], + direction_id=row["direction_id"], + shape_id=row["shape_id"], + wheelchair_accessible=row["wheelchair_accessible"], + brigade=row["brigade"], ) departures.append((stop_time, trip)) - + return departures @@ -344,14 +353,14 @@ class StopTimeRepository: def get_by_trip_id(self, trip_id: str) -> List[StopTime]: cursor = self.db.execute( "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence", - (trip_id,) + (trip_id,), ) return [StopTime.from_row(row) for row in cursor] def get_by_stop_id(self, stop_id: int) -> List[StopTime]: cursor = self.db.execute( "SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time", - (stop_id,) + (stop_id,), ) return [StopTime.from_row(row) for row in cursor] @@ -360,6 +369,7 @@ class StopTimeRepository: result = cursor.fetchone() return not result[0] + class TripRepository: db: Database @@ -454,8 +464,7 @@ class TripRepository: def get_by_headsign(self, headsign: str) -> List[Trip]: cursor = self.db.execute( - "SELECT * FROM trips WHERE trip_headsign LIKE ?", - (f"%{headsign}%",) + "SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",) ) return [Trip.from_row(row) for row in cursor] @@ -464,14 +473,24 @@ class TripRepository: result = cursor.fetchone() return not result[0] -def download_data(): + +def download_static_data(): response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile") - open("data.zip", "wb").write(response.content) - with zipfile.ZipFile("data.zip", "r") as zip_ref: + open("/tmp/data.zip", "wb").write(response.content) + with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref: zip_ref.extractall("tmp") + os.remove('/tmp/data.zip') + + +def download_realtime_data(): + response = requests.get( + "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb" + ) + open("tmp/trip_updates.pb", "wb").write(response.content) + def get_realtime(trip: Trip): - with open("trip_updates.pb", "rb") as f: + with open("tmp/trip_updates.pb", "rb") as f: response = f.read() feed = gtfs_realtime_pb2.FeedMessage() # type: ignore feed.ParseFromString(response) @@ -480,7 +499,11 @@ def get_realtime(trip: Trip): print(entity.id) print(entity.trip_update) + def main() -> None: + if not os.path.isdir("tmp"): + download_static_data() + download_realtime_data() with Database("transit.db") as db: stop_repo = StopRepository(db) stop_time_repo = StopTimeRepository(db) @@ -500,11 +523,11 @@ def main() -> None: trip_repo.load_from_csv(Path("tmp/trips.txt")) stop = stop_repo.find_by_name("Polanka")[0] - for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1): - print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}') + print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}") + for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5): + print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}") get_realtime(trip) - if __name__ == "__main__": main() diff --git a/scripts/timetable.py b/src/timetable.py similarity index 100% rename from scripts/timetable.py rename to src/timetable.py