refactor, download if needed

This commit is contained in:
Dawid Pietrykowski 2025-02-14 09:37:53 +01:00
parent 947fa37aac
commit 2711d13670
3 changed files with 53 additions and 24 deletions

View File

@ -1,3 +1,9 @@
### Getting started
To run a simple demo script you can use `uv` with:
uv run src/realtime.py
### API documentation ### API documentation
##### ZTM API docs ##### ZTM API docs

View File

@ -1,5 +1,6 @@
import zipfile import zipfile
import requests import requests
import os
from google.transit import gtfs_realtime_pb2 from google.transit import gtfs_realtime_pb2
from dataclasses import dataclass from dataclasses import dataclass
@ -40,6 +41,7 @@ class Stop:
zone_id=row["zone_id"], zone_id=row["zone_id"],
) )
@dataclass @dataclass
class Trip: class Trip:
route_id: str route_id: str
@ -77,6 +79,7 @@ class Trip:
brigade=int(row["brigade"]), brigade=int(row["brigade"]),
) )
@dataclass @dataclass
class StopTime: class StopTime:
trip_id: str trip_id: str
@ -114,6 +117,7 @@ class StopTime:
drop_off_type=int(row["drop_off_type"]), drop_off_type=int(row["drop_off_type"]),
) )
# 2. Create a typed database wrapper # 2. Create a typed database wrapper
class Database: class Database:
conn: sqlite3.Connection conn: sqlite3.Connection
@ -228,32 +232,37 @@ class StopRepository:
result = cursor.fetchone() result = cursor.fetchone()
return not result[0] # SQLite returns 0 if empty, 1 if not empty return not result[0] # SQLite returns 0 if empty, 1 if not empty
def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]: def get_upcoming_departures(
self, stop_id: int, limit: int = 10
) -> List[Tuple[StopTime, Trip]]:
"""Get upcoming departures from a stop with trip information.""" """Get upcoming departures from a stop with trip information."""
cursor = self.db.execute(""" cursor = self.db.execute(
"""
SELECT st.*, t.* SELECT st.*, t.*
FROM stop_times st FROM stop_times st
JOIN trips t ON st.trip_id = t.trip_id JOIN trips t ON st.trip_id = t.trip_id
WHERE st.stop_id = ? WHERE st.stop_id = ?
ORDER BY st.departure_time ORDER BY st.departure_time
LIMIT ? LIMIT ?
""", (stop_id, limit)) """,
(stop_id, limit),
)
departures = [] departures = []
for row in cursor: for row in cursor:
stop_time = StopTime.from_row(row) stop_time = StopTime.from_row(row)
trip = Trip( trip = Trip(
route_id=row['route_id'], route_id=row["route_id"],
service_id=row['service_id'], service_id=row["service_id"],
trip_id=row['trip_id'], trip_id=row["trip_id"],
trip_headsign=row['trip_headsign'], trip_headsign=row["trip_headsign"],
direction_id=row['direction_id'], direction_id=row["direction_id"],
shape_id=row['shape_id'], shape_id=row["shape_id"],
wheelchair_accessible=row['wheelchair_accessible'], wheelchair_accessible=row["wheelchair_accessible"],
brigade=row['brigade'] brigade=row["brigade"],
) )
departures.append((stop_time, trip)) departures.append((stop_time, trip))
return departures return departures
@ -344,14 +353,14 @@ class StopTimeRepository:
def get_by_trip_id(self, trip_id: str) -> List[StopTime]: def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
cursor = self.db.execute( cursor = self.db.execute(
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence", "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
(trip_id,) (trip_id,),
) )
return [StopTime.from_row(row) for row in cursor] return [StopTime.from_row(row) for row in cursor]
def get_by_stop_id(self, stop_id: int) -> List[StopTime]: def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
cursor = self.db.execute( cursor = self.db.execute(
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time", "SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
(stop_id,) (stop_id,),
) )
return [StopTime.from_row(row) for row in cursor] return [StopTime.from_row(row) for row in cursor]
@ -360,6 +369,7 @@ class StopTimeRepository:
result = cursor.fetchone() result = cursor.fetchone()
return not result[0] return not result[0]
class TripRepository: class TripRepository:
db: Database db: Database
@ -454,8 +464,7 @@ class TripRepository:
def get_by_headsign(self, headsign: str) -> List[Trip]: def get_by_headsign(self, headsign: str) -> List[Trip]:
cursor = self.db.execute( cursor = self.db.execute(
"SELECT * FROM trips WHERE trip_headsign LIKE ?", "SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",)
(f"%{headsign}%",)
) )
return [Trip.from_row(row) for row in cursor] return [Trip.from_row(row) for row in cursor]
@ -464,14 +473,24 @@ class TripRepository:
result = cursor.fetchone() result = cursor.fetchone()
return not result[0] return not result[0]
def download_data():
def download_static_data():
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile") response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
open("data.zip", "wb").write(response.content) open("/tmp/data.zip", "wb").write(response.content)
with zipfile.ZipFile("data.zip", "r") as zip_ref: with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
zip_ref.extractall("tmp") zip_ref.extractall("tmp")
os.remove('/tmp/data.zip')
def download_realtime_data():
response = requests.get(
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
)
open("tmp/trip_updates.pb", "wb").write(response.content)
def get_realtime(trip: Trip): def get_realtime(trip: Trip):
with open("trip_updates.pb", "rb") as f: with open("tmp/trip_updates.pb", "rb") as f:
response = f.read() response = f.read()
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
feed.ParseFromString(response) feed.ParseFromString(response)
@ -480,7 +499,11 @@ def get_realtime(trip: Trip):
print(entity.id) print(entity.id)
print(entity.trip_update) print(entity.trip_update)
def main() -> None: def main() -> None:
if not os.path.isdir("tmp"):
download_static_data()
download_realtime_data()
with Database("transit.db") as db: with Database("transit.db") as db:
stop_repo = StopRepository(db) stop_repo = StopRepository(db)
stop_time_repo = StopTimeRepository(db) stop_time_repo = StopTimeRepository(db)
@ -500,11 +523,11 @@ def main() -> None:
trip_repo.load_from_csv(Path("tmp/trips.txt")) trip_repo.load_from_csv(Path("tmp/trips.txt"))
stop = stop_repo.find_by_name("Polanka")[0] stop = stop_repo.find_by_name("Polanka")[0]
for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1): print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}') for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
get_realtime(trip) get_realtime(trip)
if __name__ == "__main__": if __name__ == "__main__":
main() main()