refactor, download if needed

This commit is contained in:
Dawid Pietrykowski 2025-02-14 09:37:53 +01:00
parent 947fa37aac
commit 2711d13670
3 changed files with 53 additions and 24 deletions

View File

@ -1,3 +1,9 @@
### Getting started
To run a simple demo script you can use `uv` with:
uv run src/realtime.py
### API documentation
##### ZTM API docs

View File

@ -1,5 +1,6 @@
import zipfile
import requests
import os
from google.transit import gtfs_realtime_pb2
from dataclasses import dataclass
@ -40,6 +41,7 @@ class Stop:
zone_id=row["zone_id"],
)
@dataclass
class Trip:
route_id: str
@ -77,6 +79,7 @@ class Trip:
brigade=int(row["brigade"]),
)
@dataclass
class StopTime:
trip_id: str
@ -114,6 +117,7 @@ class StopTime:
drop_off_type=int(row["drop_off_type"]),
)
# 2. Create a typed database wrapper
class Database:
conn: sqlite3.Connection
@ -228,32 +232,37 @@ class StopRepository:
result = cursor.fetchone()
return not result[0] # SQLite returns 0 if empty, 1 if not empty
def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]:
def get_upcoming_departures(
self, stop_id: int, limit: int = 10
) -> List[Tuple[StopTime, Trip]]:
"""Get upcoming departures from a stop with trip information."""
cursor = self.db.execute("""
cursor = self.db.execute(
"""
SELECT st.*, t.*
FROM stop_times st
JOIN trips t ON st.trip_id = t.trip_id
WHERE st.stop_id = ?
ORDER BY st.departure_time
LIMIT ?
""", (stop_id, limit))
""",
(stop_id, limit),
)
departures = []
for row in cursor:
stop_time = StopTime.from_row(row)
trip = Trip(
route_id=row['route_id'],
service_id=row['service_id'],
trip_id=row['trip_id'],
trip_headsign=row['trip_headsign'],
direction_id=row['direction_id'],
shape_id=row['shape_id'],
wheelchair_accessible=row['wheelchair_accessible'],
brigade=row['brigade']
route_id=row["route_id"],
service_id=row["service_id"],
trip_id=row["trip_id"],
trip_headsign=row["trip_headsign"],
direction_id=row["direction_id"],
shape_id=row["shape_id"],
wheelchair_accessible=row["wheelchair_accessible"],
brigade=row["brigade"],
)
departures.append((stop_time, trip))
return departures
@ -344,14 +353,14 @@ class StopTimeRepository:
def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
cursor = self.db.execute(
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
(trip_id,)
(trip_id,),
)
return [StopTime.from_row(row) for row in cursor]
def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
cursor = self.db.execute(
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
(stop_id,)
(stop_id,),
)
return [StopTime.from_row(row) for row in cursor]
@ -360,6 +369,7 @@ class StopTimeRepository:
result = cursor.fetchone()
return not result[0]
class TripRepository:
db: Database
@ -454,8 +464,7 @@ class TripRepository:
def get_by_headsign(self, headsign: str) -> List[Trip]:
cursor = self.db.execute(
"SELECT * FROM trips WHERE trip_headsign LIKE ?",
(f"%{headsign}%",)
"SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",)
)
return [Trip.from_row(row) for row in cursor]
@ -464,14 +473,24 @@ class TripRepository:
result = cursor.fetchone()
return not result[0]
def download_data():
def download_static_data():
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
open("data.zip", "wb").write(response.content)
with zipfile.ZipFile("data.zip", "r") as zip_ref:
open("/tmp/data.zip", "wb").write(response.content)
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
zip_ref.extractall("tmp")
os.remove('/tmp/data.zip')
def download_realtime_data():
response = requests.get(
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
)
open("tmp/trip_updates.pb", "wb").write(response.content)
def get_realtime(trip: Trip):
with open("trip_updates.pb", "rb") as f:
with open("tmp/trip_updates.pb", "rb") as f:
response = f.read()
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
feed.ParseFromString(response)
@ -480,7 +499,11 @@ def get_realtime(trip: Trip):
print(entity.id)
print(entity.trip_update)
def main() -> None:
if not os.path.isdir("tmp"):
download_static_data()
download_realtime_data()
with Database("transit.db") as db:
stop_repo = StopRepository(db)
stop_time_repo = StopTimeRepository(db)
@ -500,11 +523,11 @@ def main() -> None:
trip_repo.load_from_csv(Path("tmp/trips.txt"))
stop = stop_repo.find_by_name("Polanka")[0]
for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1):
print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}')
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
get_realtime(trip)
if __name__ == "__main__":
main()