refactor, download if needed
This commit is contained in:
parent
947fa37aac
commit
2711d13670
@ -1,3 +1,9 @@
|
|||||||
|
### Getting started
|
||||||
|
|
||||||
|
To run a simple demo script you can use `uv` with:
|
||||||
|
|
||||||
|
uv run src/realtime.py
|
||||||
|
|
||||||
### API documentation
|
### API documentation
|
||||||
|
|
||||||
##### ZTM API docs
|
##### ZTM API docs
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import zipfile
|
import zipfile
|
||||||
import requests
|
import requests
|
||||||
|
import os
|
||||||
from google.transit import gtfs_realtime_pb2
|
from google.transit import gtfs_realtime_pb2
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -40,6 +41,7 @@ class Stop:
|
|||||||
zone_id=row["zone_id"],
|
zone_id=row["zone_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Trip:
|
class Trip:
|
||||||
route_id: str
|
route_id: str
|
||||||
@ -77,6 +79,7 @@ class Trip:
|
|||||||
brigade=int(row["brigade"]),
|
brigade=int(row["brigade"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StopTime:
|
class StopTime:
|
||||||
trip_id: str
|
trip_id: str
|
||||||
@ -114,6 +117,7 @@ class StopTime:
|
|||||||
drop_off_type=int(row["drop_off_type"]),
|
drop_off_type=int(row["drop_off_type"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 2. Create a typed database wrapper
|
# 2. Create a typed database wrapper
|
||||||
class Database:
|
class Database:
|
||||||
conn: sqlite3.Connection
|
conn: sqlite3.Connection
|
||||||
@ -228,32 +232,37 @@ class StopRepository:
|
|||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return not result[0] # SQLite returns 0 if empty, 1 if not empty
|
return not result[0] # SQLite returns 0 if empty, 1 if not empty
|
||||||
|
|
||||||
def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]:
|
def get_upcoming_departures(
|
||||||
|
self, stop_id: int, limit: int = 10
|
||||||
|
) -> List[Tuple[StopTime, Trip]]:
|
||||||
"""Get upcoming departures from a stop with trip information."""
|
"""Get upcoming departures from a stop with trip information."""
|
||||||
cursor = self.db.execute("""
|
cursor = self.db.execute(
|
||||||
|
"""
|
||||||
SELECT st.*, t.*
|
SELECT st.*, t.*
|
||||||
FROM stop_times st
|
FROM stop_times st
|
||||||
JOIN trips t ON st.trip_id = t.trip_id
|
JOIN trips t ON st.trip_id = t.trip_id
|
||||||
WHERE st.stop_id = ?
|
WHERE st.stop_id = ?
|
||||||
ORDER BY st.departure_time
|
ORDER BY st.departure_time
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""", (stop_id, limit))
|
""",
|
||||||
|
(stop_id, limit),
|
||||||
|
)
|
||||||
|
|
||||||
departures = []
|
departures = []
|
||||||
for row in cursor:
|
for row in cursor:
|
||||||
stop_time = StopTime.from_row(row)
|
stop_time = StopTime.from_row(row)
|
||||||
trip = Trip(
|
trip = Trip(
|
||||||
route_id=row['route_id'],
|
route_id=row["route_id"],
|
||||||
service_id=row['service_id'],
|
service_id=row["service_id"],
|
||||||
trip_id=row['trip_id'],
|
trip_id=row["trip_id"],
|
||||||
trip_headsign=row['trip_headsign'],
|
trip_headsign=row["trip_headsign"],
|
||||||
direction_id=row['direction_id'],
|
direction_id=row["direction_id"],
|
||||||
shape_id=row['shape_id'],
|
shape_id=row["shape_id"],
|
||||||
wheelchair_accessible=row['wheelchair_accessible'],
|
wheelchair_accessible=row["wheelchair_accessible"],
|
||||||
brigade=row['brigade']
|
brigade=row["brigade"],
|
||||||
)
|
)
|
||||||
departures.append((stop_time, trip))
|
departures.append((stop_time, trip))
|
||||||
|
|
||||||
return departures
|
return departures
|
||||||
|
|
||||||
|
|
||||||
@ -344,14 +353,14 @@ class StopTimeRepository:
|
|||||||
def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
|
def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
|
||||||
cursor = self.db.execute(
|
cursor = self.db.execute(
|
||||||
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
|
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
|
||||||
(trip_id,)
|
(trip_id,),
|
||||||
)
|
)
|
||||||
return [StopTime.from_row(row) for row in cursor]
|
return [StopTime.from_row(row) for row in cursor]
|
||||||
|
|
||||||
def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
|
def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
|
||||||
cursor = self.db.execute(
|
cursor = self.db.execute(
|
||||||
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
|
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
|
||||||
(stop_id,)
|
(stop_id,),
|
||||||
)
|
)
|
||||||
return [StopTime.from_row(row) for row in cursor]
|
return [StopTime.from_row(row) for row in cursor]
|
||||||
|
|
||||||
@ -360,6 +369,7 @@ class StopTimeRepository:
|
|||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return not result[0]
|
return not result[0]
|
||||||
|
|
||||||
|
|
||||||
class TripRepository:
|
class TripRepository:
|
||||||
db: Database
|
db: Database
|
||||||
|
|
||||||
@ -454,8 +464,7 @@ class TripRepository:
|
|||||||
|
|
||||||
def get_by_headsign(self, headsign: str) -> List[Trip]:
|
def get_by_headsign(self, headsign: str) -> List[Trip]:
|
||||||
cursor = self.db.execute(
|
cursor = self.db.execute(
|
||||||
"SELECT * FROM trips WHERE trip_headsign LIKE ?",
|
"SELECT * FROM trips WHERE trip_headsign LIKE ?", (f"%{headsign}%",)
|
||||||
(f"%{headsign}%",)
|
|
||||||
)
|
)
|
||||||
return [Trip.from_row(row) for row in cursor]
|
return [Trip.from_row(row) for row in cursor]
|
||||||
|
|
||||||
@ -464,14 +473,24 @@ class TripRepository:
|
|||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return not result[0]
|
return not result[0]
|
||||||
|
|
||||||
def download_data():
|
|
||||||
|
def download_static_data():
|
||||||
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
|
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
|
||||||
open("data.zip", "wb").write(response.content)
|
open("/tmp/data.zip", "wb").write(response.content)
|
||||||
with zipfile.ZipFile("data.zip", "r") as zip_ref:
|
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
|
||||||
zip_ref.extractall("tmp")
|
zip_ref.extractall("tmp")
|
||||||
|
os.remove('/tmp/data.zip')
|
||||||
|
|
||||||
|
|
||||||
|
def download_realtime_data():
|
||||||
|
response = requests.get(
|
||||||
|
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
|
||||||
|
)
|
||||||
|
open("tmp/trip_updates.pb", "wb").write(response.content)
|
||||||
|
|
||||||
|
|
||||||
def get_realtime(trip: Trip):
|
def get_realtime(trip: Trip):
|
||||||
with open("trip_updates.pb", "rb") as f:
|
with open("tmp/trip_updates.pb", "rb") as f:
|
||||||
response = f.read()
|
response = f.read()
|
||||||
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
|
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
|
||||||
feed.ParseFromString(response)
|
feed.ParseFromString(response)
|
||||||
@ -480,7 +499,11 @@ def get_realtime(trip: Trip):
|
|||||||
print(entity.id)
|
print(entity.id)
|
||||||
print(entity.trip_update)
|
print(entity.trip_update)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
if not os.path.isdir("tmp"):
|
||||||
|
download_static_data()
|
||||||
|
download_realtime_data()
|
||||||
with Database("transit.db") as db:
|
with Database("transit.db") as db:
|
||||||
stop_repo = StopRepository(db)
|
stop_repo = StopRepository(db)
|
||||||
stop_time_repo = StopTimeRepository(db)
|
stop_time_repo = StopTimeRepository(db)
|
||||||
@ -500,11 +523,11 @@ def main() -> None:
|
|||||||
trip_repo.load_from_csv(Path("tmp/trips.txt"))
|
trip_repo.load_from_csv(Path("tmp/trips.txt"))
|
||||||
|
|
||||||
stop = stop_repo.find_by_name("Polanka")[0]
|
stop = stop_repo.find_by_name("Polanka")[0]
|
||||||
for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1):
|
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
|
||||||
print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}')
|
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
|
||||||
|
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
|
||||||
get_realtime(trip)
|
get_realtime(trip)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
x
Reference in New Issue
Block a user