matched realtime with static data

This commit is contained in:
Dawid Pietrykowski 2025-02-15 23:51:15 +01:00
parent 2711d13670
commit 94321b1fe6

View File

@ -1,14 +1,17 @@
import zipfile
import requests
import os
from google.transit import gtfs_realtime_pb2
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Iterator, Tuple from datetime import datetime, timedelta
import sqlite3 from google.transit import gtfs_realtime_pb2
import csv
from pathlib import Path from pathlib import Path
from typing import List, Optional, Iterator, Tuple, Dict
import argparse
import csv
import os
from os import path
import requests
import sqlite3
import zipfile
DATA_PATH: Path = Path('tmp')
@dataclass @dataclass
class Stop: class Stop:
@ -79,12 +82,19 @@ class Trip:
brigade=int(row["brigade"]), brigade=int(row["brigade"]),
) )
def normalize_time(time_str: str) -> str:
"""Convert 24:00:00 format to 00:00:00 of the next day."""
if time_str.startswith('24:'):
return '00' + time_str[2:]
if time_str.startswith('25:'):
return '01' + time_str[2:]
return time_str
@dataclass @dataclass
class StopTime: class StopTime:
trip_id: str trip_id: str
arrival_time: str # Using str for now since SQLite doesn't have a TIME type arrival_time: datetime
departure_time: str departure_time: datetime
stop_id: int stop_id: int
stop_sequence: int stop_sequence: int
stop_headsign: str stop_headsign: str
@ -95,8 +105,8 @@ class StopTime:
def from_row(cls, row: sqlite3.Row) -> "StopTime": def from_row(cls, row: sqlite3.Row) -> "StopTime":
return cls( return cls(
trip_id=row["trip_id"], trip_id=row["trip_id"],
arrival_time=row["arrival_time"], arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
departure_time=row["departure_time"], departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
stop_id=row["stop_id"], stop_id=row["stop_id"],
stop_sequence=row["stop_sequence"], stop_sequence=row["stop_sequence"],
stop_headsign=row["stop_headsign"], stop_headsign=row["stop_headsign"],
@ -108,8 +118,8 @@ class StopTime:
def from_csv_row(cls, row) -> "StopTime": def from_csv_row(cls, row) -> "StopTime":
return cls( return cls(
trip_id=row["trip_id"].strip('"'), trip_id=row["trip_id"].strip('"'),
arrival_time=row["arrival_time"], arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
departure_time=row["departure_time"], departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
stop_id=int(row["stop_id"]), stop_id=int(row["stop_id"]),
stop_sequence=int(row["stop_sequence"]), stop_sequence=int(row["stop_sequence"]),
stop_headsign=row["stop_headsign"].strip('"'), stop_headsign=row["stop_headsign"].strip('"'),
@ -117,8 +127,57 @@ class StopTime:
drop_off_type=int(row["drop_off_type"]), drop_off_type=int(row["drop_off_type"]),
) )
@dataclass
class RealtimeFrame:
trip_id: str
route_id: str
stop_sequence: int
delay: int
vehicle_id: str
vehicle_label: str
timestamp: datetime
@staticmethod
def from_feed_entry(entry) -> 'RealtimeFrame':
assert len(entry.stop_time_update) == 1
return RealtimeFrame(
trip_id=entry.trip.trip_id,
route_id=entry.trip.route_id,
stop_sequence=int(entry.stop_time_update[0].stop_sequence),
delay=int(entry.stop_time_update[0].arrival.delay),
vehicle_id=entry.vehicle.id,
vehicle_label=entry.vehicle.label,
timestamp=datetime.fromtimestamp(int(entry.timestamp))
)
class Calendar:
start_date: datetime
end_date: datetime
weekdays: Dict[int, int] = {}
keys: Dict[str, int] = {
'monday': 0,
'tuesday': 1,
'wednesday': 2,
'thursday': 3,
'friday': 4,
'saturday': 5,
'sunday': 6,
}
def load_from_csv(self, csv_path: Path):
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
for day, id in self.keys.items():
if row[day] == '1':
self.weekdays[id] = int(row['service_id'])
if len(self.weekdays.items()) != 7:
raise Exception('Could not match every weekday')
def get_todays_service_id(self) -> int:
return self.weekdays[datetime.now().weekday()]
# 2. Create a typed database wrapper
class Database: class Database:
conn: sqlite3.Connection conn: sqlite3.Connection
@ -142,8 +201,6 @@ class Database:
def commit(self) -> None: def commit(self) -> None:
self.conn.commit() self.conn.commit()
# 3. Create a typed repository for stops
class StopRepository: class StopRepository:
db: Database db: Database
@ -233,19 +290,21 @@ class StopRepository:
return not result[0] # SQLite returns 0 if empty, 1 if not empty return not result[0] # SQLite returns 0 if empty, 1 if not empty
def get_upcoming_departures( def get_upcoming_departures(
self, stop_id: int, limit: int = 10 self, calendar: Calendar, stop_id: int, limit: int = 10
) -> List[Tuple[StopTime, Trip]]: ) -> List[Tuple[StopTime, Trip]]:
"""Get upcoming departures from a stop with trip information.""" """Get upcoming departures from a stop with trip information."""
current_time = datetime.now().strftime("%H:%M:%S")
service_id = calendar.get_todays_service_id()
cursor = self.db.execute( cursor = self.db.execute(
""" """
SELECT st.*, t.* SELECT st.*, t.*
FROM stop_times st FROM stop_times st
JOIN trips t ON st.trip_id = t.trip_id JOIN trips t ON st.trip_id = t.trip_id
WHERE st.stop_id = ? WHERE st.stop_id = ? AND st.departure_time >= ? AND t.service_id = ?
ORDER BY st.departure_time ORDER BY st.departure_time
LIMIT ? LIMIT ?
""", """,
(stop_id, limit), (stop_id, current_time, service_id, limit),
) )
departures = [] departures = []
@ -475,39 +534,55 @@ class TripRepository:
def download_static_data(): def download_static_data():
zip_path = '/tmp/data.zip'
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile") response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
open("/tmp/data.zip", "wb").write(response.content) open(zip_path, "wb").write(response.content)
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref: with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall("tmp") zip_ref.extractall("tmp")
os.remove('/tmp/data.zip') os.remove(zip_path)
def download_realtime_data(): def download_realtime_data():
response = requests.get( response = requests.get(
# "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=feeds.pb"
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb" "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
) )
open("tmp/trip_updates.pb", "wb").write(response.content) open(path.join(DATA_PATH, 'trip_updates.pb'), "wb").write(response.content)
def get_realtime(trip: Trip): def get_realtime(trip: Trip) -> None | RealtimeFrame:
with open("tmp/trip_updates.pb", "rb") as f: with open(path.join(DATA_PATH, 'trip_updates.pb'), "rb") as f:
response = f.read() response = f.read()
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
feed.ParseFromString(response) feed.ParseFromString(response)
print(feed, file=open('feed.txt', "w"))
for entity in feed.entity: for entity in feed.entity:
if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id): frame = RealtimeFrame.from_feed_entry(entity.trip_update)
print(entity.id) if frame.trip_id == trip.trip_id:
print(entity.trip_update) return frame
def main() -> None: def main() -> None:
if not os.path.isdir("tmp"): parser = argparse.ArgumentParser(description='Process a file with optional verbose mode')
parser.add_argument('stop', help='Name of the stop')
parser.add_argument('-n', '--count', type=int, default=0, help='Departures count')
args = parser.parse_args()
# Download delay data
download_realtime_data()
# Download static data
if not os.path.isdir(DATA_PATH):
download_static_data() download_static_data()
download_realtime_data() with Database(path.join(DATA_PATH, "transit.db")) as db:
with Database("transit.db") as db:
stop_repo = StopRepository(db) stop_repo = StopRepository(db)
stop_time_repo = StopTimeRepository(db) stop_time_repo = StopTimeRepository(db)
trip_repo = TripRepository(db) trip_repo = TripRepository(db)
calendar = Calendar()
calendar.load_from_csv(Path(path.join(DATA_PATH, 'calendar.txt')))
# Create tables # Create tables
stop_repo.create_table() stop_repo.create_table()
@ -516,18 +591,48 @@ def main() -> None:
# Load data if tables are empty # Load data if tables are empty
if stop_repo.is_empty(): if stop_repo.is_empty():
stop_repo.load_from_csv(Path("tmp/stops.txt")) stop_repo.load_from_csv(Path(path.join(DATA_PATH, "stops.txt")))
if stop_time_repo.is_empty(): if stop_time_repo.is_empty():
stop_time_repo.load_from_csv(Path("tmp/stop_times.txt")) stop_time_repo.load_from_csv(Path(path.join(DATA_PATH, "stop_times.txt")))
if trip_repo.is_empty(): if trip_repo.is_empty():
trip_repo.load_from_csv(Path("tmp/trips.txt")) trip_repo.load_from_csv(Path(path.join(DATA_PATH, "trips.txt")))
stop = stop_repo.find_by_name("Polanka")[0] stops = stop_repo.find_by_name(args.stop)
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}") for stop in stops:
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5): print(f"\n\nDepartures from stop: {stop.stop_name} with id: {stop.stop_id}")
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
get_realtime(trip)
upcoming_departures = stop_repo.get_upcoming_departures(calendar, stop.stop_id, args.count)
for stop_time, trip in upcoming_departures:
# Check for realtime data
realtime: None | RealtimeFrame = get_realtime(trip)
departure_string: str = (
f'{stop_time.arrival_time.strftime('%H:%M:%S')}'
'(no realtime data)'
)
if realtime is not None:
# Calculate actual departure time
delay_str = (
f'+{realtime.delay}s'
if realtime.delay > 0
else f'{realtime.delay}s'
)
stop_time.arrival_time += timedelta(seconds=realtime.delay)
last_update_time = datetime.now() - realtime.timestamp
departure_string = (
f'{stop_time.arrival_time.strftime("%H:%M:%S")} '
f'({delay_str} updated {int(last_update_time.total_seconds())}s ago)'
)
print(
f"\nLine: {trip.route_id} "
f"Departure: {departure_string} "
f"Direction: {trip.trip_headsign}"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()