matched realtime with static data

This commit is contained in:
Dawid Pietrykowski 2025-02-15 23:51:15 +01:00
parent 2711d13670
commit 94321b1fe6

View File

@ -1,14 +1,17 @@
import zipfile
import requests
import os
from google.transit import gtfs_realtime_pb2
from dataclasses import dataclass
from typing import List, Optional, Iterator, Tuple
import sqlite3
import csv
from datetime import datetime, timedelta
from google.transit import gtfs_realtime_pb2
from pathlib import Path
from typing import List, Optional, Iterator, Tuple, Dict
import argparse
import csv
import os
from os import path
import requests
import sqlite3
import zipfile
DATA_PATH: Path = Path('tmp')
@dataclass
class Stop:
@ -79,12 +82,19 @@ class Trip:
brigade=int(row["brigade"]),
)
def normalize_time(time_str: str) -> str:
"""Convert 24:00:00 format to 00:00:00 of the next day."""
if time_str.startswith('24:'):
return '00' + time_str[2:]
if time_str.startswith('25:'):
return '01' + time_str[2:]
return time_str
@dataclass
class StopTime:
trip_id: str
arrival_time: str # Using str for now since SQLite doesn't have a TIME type
departure_time: str
arrival_time: datetime
departure_time: datetime
stop_id: int
stop_sequence: int
stop_headsign: str
@ -95,8 +105,8 @@ class StopTime:
def from_row(cls, row: sqlite3.Row) -> "StopTime":
return cls(
trip_id=row["trip_id"],
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
stop_id=row["stop_id"],
stop_sequence=row["stop_sequence"],
stop_headsign=row["stop_headsign"],
@ -108,8 +118,8 @@ class StopTime:
def from_csv_row(cls, row) -> "StopTime":
return cls(
trip_id=row["trip_id"].strip('"'),
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
stop_id=int(row["stop_id"]),
stop_sequence=int(row["stop_sequence"]),
stop_headsign=row["stop_headsign"].strip('"'),
@ -117,8 +127,57 @@ class StopTime:
drop_off_type=int(row["drop_off_type"]),
)
@dataclass
class RealtimeFrame:
trip_id: str
route_id: str
stop_sequence: int
delay: int
vehicle_id: str
vehicle_label: str
timestamp: datetime
@staticmethod
def from_feed_entry(entry) -> 'RealtimeFrame':
assert len(entry.stop_time_update) == 1
return RealtimeFrame(
trip_id=entry.trip.trip_id,
route_id=entry.trip.route_id,
stop_sequence=int(entry.stop_time_update[0].stop_sequence),
delay=int(entry.stop_time_update[0].arrival.delay),
vehicle_id=entry.vehicle.id,
vehicle_label=entry.vehicle.label,
timestamp=datetime.fromtimestamp(int(entry.timestamp))
)
class Calendar:
start_date: datetime
end_date: datetime
weekdays: Dict[int, int] = {}
keys: Dict[str, int] = {
'monday': 0,
'tuesday': 1,
'wednesday': 2,
'thursday': 3,
'friday': 4,
'saturday': 5,
'sunday': 6,
}
def load_from_csv(self, csv_path: Path):
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
for day, id in self.keys.items():
if row[day] == '1':
self.weekdays[id] = int(row['service_id'])
if len(self.weekdays.items()) != 7:
raise Exception('Could not match every weekday')
def get_todays_service_id(self) -> int:
return self.weekdays[datetime.now().weekday()]
# 2. Create a typed database wrapper
class Database:
conn: sqlite3.Connection
@ -142,8 +201,6 @@ class Database:
def commit(self) -> None:
self.conn.commit()
# 3. Create a typed repository for stops
class StopRepository:
db: Database
@ -233,19 +290,21 @@ class StopRepository:
return not result[0] # SQLite returns 0 if empty, 1 if not empty
def get_upcoming_departures(
self, stop_id: int, limit: int = 10
self, calendar: Calendar, stop_id: int, limit: int = 10
) -> List[Tuple[StopTime, Trip]]:
"""Get upcoming departures from a stop with trip information."""
current_time = datetime.now().strftime("%H:%M:%S")
service_id = calendar.get_todays_service_id()
cursor = self.db.execute(
"""
SELECT st.*, t.*
FROM stop_times st
JOIN trips t ON st.trip_id = t.trip_id
WHERE st.stop_id = ?
WHERE st.stop_id = ? AND st.departure_time >= ? AND t.service_id = ?
ORDER BY st.departure_time
LIMIT ?
""",
(stop_id, limit),
(stop_id, current_time, service_id, limit),
)
departures = []
@ -475,39 +534,55 @@ class TripRepository:
def download_static_data():
zip_path = '/tmp/data.zip'
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
open("/tmp/data.zip", "wb").write(response.content)
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
open(zip_path, "wb").write(response.content)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall("tmp")
os.remove('/tmp/data.zip')
os.remove(zip_path)
def download_realtime_data():
response = requests.get(
# "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=feeds.pb"
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
)
open("tmp/trip_updates.pb", "wb").write(response.content)
open(path.join(DATA_PATH, 'trip_updates.pb'), "wb").write(response.content)
def get_realtime(trip: Trip):
with open("tmp/trip_updates.pb", "rb") as f:
def get_realtime(trip: Trip) -> None | RealtimeFrame:
with open(path.join(DATA_PATH, 'trip_updates.pb'), "rb") as f:
response = f.read()
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
feed.ParseFromString(response)
print(feed, file=open('feed.txt', "w"))
for entity in feed.entity:
if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id):
print(entity.id)
print(entity.trip_update)
frame = RealtimeFrame.from_feed_entry(entity.trip_update)
if frame.trip_id == trip.trip_id:
return frame
def main() -> None:
if not os.path.isdir("tmp"):
parser = argparse.ArgumentParser(description='Process a file with optional verbose mode')
parser.add_argument('stop', help='Name of the stop')
parser.add_argument('-n', '--count', type=int, default=0, help='Departures count')
args = parser.parse_args()
# Download delay data
download_realtime_data()
# Download static data
if not os.path.isdir(DATA_PATH):
download_static_data()
download_realtime_data()
with Database("transit.db") as db:
with Database(path.join(DATA_PATH, "transit.db")) as db:
stop_repo = StopRepository(db)
stop_time_repo = StopTimeRepository(db)
trip_repo = TripRepository(db)
calendar = Calendar()
calendar.load_from_csv(Path(path.join(DATA_PATH, 'calendar.txt')))
# Create tables
stop_repo.create_table()
@ -516,18 +591,48 @@ def main() -> None:
# Load data if tables are empty
if stop_repo.is_empty():
stop_repo.load_from_csv(Path("tmp/stops.txt"))
stop_repo.load_from_csv(Path(path.join(DATA_PATH, "stops.txt")))
if stop_time_repo.is_empty():
stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
stop_time_repo.load_from_csv(Path(path.join(DATA_PATH, "stop_times.txt")))
if trip_repo.is_empty():
trip_repo.load_from_csv(Path("tmp/trips.txt"))
trip_repo.load_from_csv(Path(path.join(DATA_PATH, "trips.txt")))
stop = stop_repo.find_by_name("Polanka")[0]
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
get_realtime(trip)
stops = stop_repo.find_by_name(args.stop)
for stop in stops:
print(f"\n\nDepartures from stop: {stop.stop_name} with id: {stop.stop_id}")
upcoming_departures = stop_repo.get_upcoming_departures(calendar, stop.stop_id, args.count)
for stop_time, trip in upcoming_departures:
# Check for realtime data
realtime: None | RealtimeFrame = get_realtime(trip)
departure_string: str = (
f'{stop_time.arrival_time.strftime('%H:%M:%S')}'
'(no realtime data)'
)
if realtime is not None:
# Calculate actual departure time
delay_str = (
f'+{realtime.delay}s'
if realtime.delay > 0
else f'{realtime.delay}s'
)
stop_time.arrival_time += timedelta(seconds=realtime.delay)
last_update_time = datetime.now() - realtime.timestamp
departure_string = (
f'{stop_time.arrival_time.strftime("%H:%M:%S")} '
f'({delay_str} updated {int(last_update_time.total_seconds())}s ago)'
)
print(
f"\nLine: {trip.route_id} "
f"Departure: {departure_string} "
f"Direction: {trip.trip_headsign}"
)
if __name__ == "__main__":
main()