refactor, download if needed
This commit is contained in:
parent
947fa37aac
commit
2711d13670
@ -1,3 +1,9 @@
|
||||
### Getting started
|
||||
|
||||
To run a simple demo script you can use `uv` with:
|
||||
|
||||
uv run src/realtime.py
|
||||
|
||||
### API documentation
|
||||
|
||||
##### ZTM API docs
|
||||
|
@ -1,5 +1,6 @@
|
||||
import zipfile
|
||||
import requests
|
||||
import os
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -40,6 +41,7 @@ class Stop:
|
||||
zone_id=row["zone_id"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trip:
|
||||
route_id: str
|
||||
@ -77,6 +79,7 @@ class Trip:
|
||||
brigade=int(row["brigade"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopTime:
|
||||
trip_id: str
|
||||
@ -114,6 +117,7 @@ class StopTime:
|
||||
drop_off_type=int(row["drop_off_type"]),
|
||||
)
|
||||
|
||||
|
||||
# 2. Create a typed database wrapper
|
||||
class Database:
|
||||
conn: sqlite3.Connection
|
||||
@ -228,32 +232,37 @@ class StopRepository:
|
||||
result = cursor.fetchone()
|
||||
return not result[0] # SQLite returns 0 if empty, 1 if not empty
|
||||
|
||||
def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]:
|
||||
def get_upcoming_departures(
|
||||
self, stop_id: int, limit: int = 10
|
||||
) -> List[Tuple[StopTime, Trip]]:
|
||||
"""Get upcoming departures from a stop with trip information."""
|
||||
cursor = self.db.execute("""
|
||||
cursor = self.db.execute(
|
||||
"""
|
||||
SELECT st.*, t.*
|
||||
FROM stop_times st
|
||||
JOIN trips t ON st.trip_id = t.trip_id
|
||||
WHERE st.stop_id = ?
|
||||
ORDER BY st.departure_time
|
||||
LIMIT ?
|
||||
""", (stop_id, limit))
|
||||
|
||||
""",
|
||||
(stop_id, limit),
|
||||
)
|
||||
|
||||
departures = []
|
||||
for row in cursor:
|
||||
stop_time = StopTime.from_row(row)
|
||||
trip = Trip(
|
||||
route_id=row['route_id'],
|
||||
service_id=row['service_id'],
|
||||
trip_id=row['trip_id'],
|
||||
trip_headsign=row['trip_headsign'],
|
||||
direction_id=row['direction_id'],
|
||||
shape_id=row['shape_id'],
|
||||
wheelchair_accessible=row['wheelchair_accessible'],
|
||||
brigade=row['brigade']
|
||||
route_id=row["route_id"],
|
||||
service_id=row["service_id"],
|
||||
trip_id=row["trip_id"],
|
||||
trip_headsign=row["trip_headsign"],
|
||||
direction_id=row["direction_id"],
|
||||
shape_id=row["shape_id"],
|
||||
wheelchair_accessible=row["wheelchair_accessible"],
|
||||
brigade=row["brigade"],
|
||||
)
|
||||
departures.append((stop_time, trip))
|
||||
|
||||
|
||||
return departures
|
||||
|
||||
|
||||
@ -344,14 +353,14 @@ class StopTimeRepository:
|
||||
def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
|
||||
cursor = self.db.execute(
|
||||
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
|
||||
(trip_id,)
|
||||
(trip_id,),
|
||||
)
|
||||
return [StopTime.from_row(row) for row in cursor]
|
||||
|
||||
def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
|
||||
cursor = self.db.execute(
|
||||
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
|
||||
(stop_id,)
|
||||
(stop_id,),
|
||||
)
|
||||
return [StopTime.from_row(row) for row in cursor]
|
||||
|
||||
@ -360,6 +369,7 @@ class StopTimeRepository:
|
||||
result = cursor.fetchone()
|
||||
return not result[0]
|
||||
|
||||
|
||||
class TripRepository:
|
||||
db: Database
|
||||
|
||||
@ -454,8 +464,7 @@ class TripRepository:
|
||||
|
||||
def get_by_headsign(self, headsign: str) -> List[Trip]:
|
||||
cursor = self.db.execute(
|
||||
"SELECT * FROM trips WHERE trip_headsign LIKE ?",
|
||||
(f"%{headsign}%",)
|
||||
"SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",)
|
||||
)
|
||||
return [Trip.from_row(row) for row in cursor]
|
||||
|
||||
@ -464,14 +473,24 @@ class TripRepository:
|
||||
result = cursor.fetchone()
|
||||
return not result[0]
|
||||
|
||||
def download_data():
|
||||
|
||||
def download_static_data():
|
||||
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
|
||||
open("data.zip", "wb").write(response.content)
|
||||
with zipfile.ZipFile("data.zip", "r") as zip_ref:
|
||||
open("/tmp/data.zip", "wb").write(response.content)
|
||||
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
|
||||
zip_ref.extractall("tmp")
|
||||
os.remove('/tmp/data.zip')
|
||||
|
||||
|
||||
def download_realtime_data():
|
||||
response = requests.get(
|
||||
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
|
||||
)
|
||||
open("tmp/trip_updates.pb", "wb").write(response.content)
|
||||
|
||||
|
||||
def get_realtime(trip: Trip):
|
||||
with open("trip_updates.pb", "rb") as f:
|
||||
with open("tmp/trip_updates.pb", "rb") as f:
|
||||
response = f.read()
|
||||
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
|
||||
feed.ParseFromString(response)
|
||||
@ -480,7 +499,11 @@ def get_realtime(trip: Trip):
|
||||
print(entity.id)
|
||||
print(entity.trip_update)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not os.path.isdir("tmp"):
|
||||
download_static_data()
|
||||
download_realtime_data()
|
||||
with Database("transit.db") as db:
|
||||
stop_repo = StopRepository(db)
|
||||
stop_time_repo = StopTimeRepository(db)
|
||||
@ -500,11 +523,11 @@ def main() -> None:
|
||||
trip_repo.load_from_csv(Path("tmp/trips.txt"))
|
||||
|
||||
stop = stop_repo.find_by_name("Polanka")[0]
|
||||
for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1):
|
||||
print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}')
|
||||
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
|
||||
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
|
||||
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
|
||||
get_realtime(trip)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user