matched realtime with static data
This commit is contained in:
parent
2711d13670
commit
94321b1fe6
185
src/realtime.py
185
src/realtime.py
@ -1,14 +1,17 @@
|
|||||||
import zipfile
|
|
||||||
import requests
|
|
||||||
import os
|
|
||||||
from google.transit import gtfs_realtime_pb2
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Iterator, Tuple
|
from datetime import datetime, timedelta
|
||||||
import sqlite3
|
from google.transit import gtfs_realtime_pb2
|
||||||
import csv
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Iterator, Tuple, Dict
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from os import path
|
||||||
|
import requests
|
||||||
|
import sqlite3
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
DATA_PATH: Path = Path('tmp')
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Stop:
|
class Stop:
|
||||||
@ -79,12 +82,19 @@ class Trip:
|
|||||||
brigade=int(row["brigade"]),
|
brigade=int(row["brigade"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def normalize_time(time_str: str) -> str:
|
||||||
|
"""Convert 24:00:00 format to 00:00:00 of the next day."""
|
||||||
|
if time_str.startswith('24:'):
|
||||||
|
return '00' + time_str[2:]
|
||||||
|
if time_str.startswith('25:'):
|
||||||
|
return '01' + time_str[2:]
|
||||||
|
return time_str
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StopTime:
|
class StopTime:
|
||||||
trip_id: str
|
trip_id: str
|
||||||
arrival_time: str # Using str for now since SQLite doesn't have a TIME type
|
arrival_time: datetime
|
||||||
departure_time: str
|
departure_time: datetime
|
||||||
stop_id: int
|
stop_id: int
|
||||||
stop_sequence: int
|
stop_sequence: int
|
||||||
stop_headsign: str
|
stop_headsign: str
|
||||||
@ -95,8 +105,8 @@ class StopTime:
|
|||||||
def from_row(cls, row: sqlite3.Row) -> "StopTime":
|
def from_row(cls, row: sqlite3.Row) -> "StopTime":
|
||||||
return cls(
|
return cls(
|
||||||
trip_id=row["trip_id"],
|
trip_id=row["trip_id"],
|
||||||
arrival_time=row["arrival_time"],
|
arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
|
||||||
departure_time=row["departure_time"],
|
departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
|
||||||
stop_id=row["stop_id"],
|
stop_id=row["stop_id"],
|
||||||
stop_sequence=row["stop_sequence"],
|
stop_sequence=row["stop_sequence"],
|
||||||
stop_headsign=row["stop_headsign"],
|
stop_headsign=row["stop_headsign"],
|
||||||
@ -108,8 +118,8 @@ class StopTime:
|
|||||||
def from_csv_row(cls, row) -> "StopTime":
|
def from_csv_row(cls, row) -> "StopTime":
|
||||||
return cls(
|
return cls(
|
||||||
trip_id=row["trip_id"].strip('"'),
|
trip_id=row["trip_id"].strip('"'),
|
||||||
arrival_time=row["arrival_time"],
|
arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"),
|
||||||
departure_time=row["departure_time"],
|
departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"),
|
||||||
stop_id=int(row["stop_id"]),
|
stop_id=int(row["stop_id"]),
|
||||||
stop_sequence=int(row["stop_sequence"]),
|
stop_sequence=int(row["stop_sequence"]),
|
||||||
stop_headsign=row["stop_headsign"].strip('"'),
|
stop_headsign=row["stop_headsign"].strip('"'),
|
||||||
@ -117,8 +127,57 @@ class StopTime:
|
|||||||
drop_off_type=int(row["drop_off_type"]),
|
drop_off_type=int(row["drop_off_type"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RealtimeFrame:
|
||||||
|
trip_id: str
|
||||||
|
route_id: str
|
||||||
|
stop_sequence: int
|
||||||
|
delay: int
|
||||||
|
vehicle_id: str
|
||||||
|
vehicle_label: str
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_feed_entry(entry) -> 'RealtimeFrame':
|
||||||
|
assert len(entry.stop_time_update) == 1
|
||||||
|
return RealtimeFrame(
|
||||||
|
trip_id=entry.trip.trip_id,
|
||||||
|
route_id=entry.trip.route_id,
|
||||||
|
stop_sequence=int(entry.stop_time_update[0].stop_sequence),
|
||||||
|
delay=int(entry.stop_time_update[0].arrival.delay),
|
||||||
|
vehicle_id=entry.vehicle.id,
|
||||||
|
vehicle_label=entry.vehicle.label,
|
||||||
|
timestamp=datetime.fromtimestamp(int(entry.timestamp))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Calendar:
|
||||||
|
start_date: datetime
|
||||||
|
end_date: datetime
|
||||||
|
weekdays: Dict[int, int] = {}
|
||||||
|
keys: Dict[str, int] = {
|
||||||
|
'monday': 0,
|
||||||
|
'tuesday': 1,
|
||||||
|
'wednesday': 2,
|
||||||
|
'thursday': 3,
|
||||||
|
'friday': 4,
|
||||||
|
'saturday': 5,
|
||||||
|
'sunday': 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_from_csv(self, csv_path: Path):
|
||||||
|
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
|
||||||
|
reader = csv.DictReader(csvfile)
|
||||||
|
for row in reader:
|
||||||
|
for day, id in self.keys.items():
|
||||||
|
if row[day] == '1':
|
||||||
|
self.weekdays[id] = int(row['service_id'])
|
||||||
|
if len(self.weekdays.items()) != 7:
|
||||||
|
raise Exception('Could not match every weekday')
|
||||||
|
|
||||||
|
def get_todays_service_id(self) -> int:
|
||||||
|
return self.weekdays[datetime.now().weekday()]
|
||||||
|
|
||||||
# 2. Create a typed database wrapper
|
|
||||||
class Database:
|
class Database:
|
||||||
conn: sqlite3.Connection
|
conn: sqlite3.Connection
|
||||||
|
|
||||||
@ -142,8 +201,6 @@ class Database:
|
|||||||
def commit(self) -> None:
|
def commit(self) -> None:
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
|
|
||||||
# 3. Create a typed repository for stops
|
|
||||||
class StopRepository:
|
class StopRepository:
|
||||||
db: Database
|
db: Database
|
||||||
|
|
||||||
@ -233,19 +290,21 @@ class StopRepository:
|
|||||||
return not result[0] # SQLite returns 0 if empty, 1 if not empty
|
return not result[0] # SQLite returns 0 if empty, 1 if not empty
|
||||||
|
|
||||||
def get_upcoming_departures(
|
def get_upcoming_departures(
|
||||||
self, stop_id: int, limit: int = 10
|
self, calendar: Calendar, stop_id: int, limit: int = 10
|
||||||
) -> List[Tuple[StopTime, Trip]]:
|
) -> List[Tuple[StopTime, Trip]]:
|
||||||
"""Get upcoming departures from a stop with trip information."""
|
"""Get upcoming departures from a stop with trip information."""
|
||||||
|
current_time = datetime.now().strftime("%H:%M:%S")
|
||||||
|
service_id = calendar.get_todays_service_id()
|
||||||
cursor = self.db.execute(
|
cursor = self.db.execute(
|
||||||
"""
|
"""
|
||||||
SELECT st.*, t.*
|
SELECT st.*, t.*
|
||||||
FROM stop_times st
|
FROM stop_times st
|
||||||
JOIN trips t ON st.trip_id = t.trip_id
|
JOIN trips t ON st.trip_id = t.trip_id
|
||||||
WHERE st.stop_id = ?
|
WHERE st.stop_id = ? AND st.departure_time >= ? AND t.service_id = ?
|
||||||
ORDER BY st.departure_time
|
ORDER BY st.departure_time
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""",
|
""",
|
||||||
(stop_id, limit),
|
(stop_id, current_time, service_id, limit),
|
||||||
)
|
)
|
||||||
|
|
||||||
departures = []
|
departures = []
|
||||||
@ -475,39 +534,55 @@ class TripRepository:
|
|||||||
|
|
||||||
|
|
||||||
def download_static_data():
|
def download_static_data():
|
||||||
|
zip_path = '/tmp/data.zip'
|
||||||
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
|
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
|
||||||
open("/tmp/data.zip", "wb").write(response.content)
|
open(zip_path, "wb").write(response.content)
|
||||||
with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
zip_ref.extractall("tmp")
|
zip_ref.extractall("tmp")
|
||||||
os.remove('/tmp/data.zip')
|
os.remove(zip_path)
|
||||||
|
|
||||||
|
|
||||||
def download_realtime_data():
|
def download_realtime_data():
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
|
# "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=feeds.pb"
|
||||||
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
|
"https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb"
|
||||||
)
|
)
|
||||||
open("tmp/trip_updates.pb", "wb").write(response.content)
|
open(path.join(DATA_PATH, 'trip_updates.pb'), "wb").write(response.content)
|
||||||
|
|
||||||
|
|
||||||
def get_realtime(trip: Trip):
|
def get_realtime(trip: Trip) -> None | RealtimeFrame:
|
||||||
with open("tmp/trip_updates.pb", "rb") as f:
|
with open(path.join(DATA_PATH, 'trip_updates.pb'), "rb") as f:
|
||||||
response = f.read()
|
response = f.read()
|
||||||
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
|
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
|
||||||
feed.ParseFromString(response)
|
feed.ParseFromString(response)
|
||||||
|
print(feed, file=open('feed.txt', "w"))
|
||||||
for entity in feed.entity:
|
for entity in feed.entity:
|
||||||
if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id):
|
frame = RealtimeFrame.from_feed_entry(entity.trip_update)
|
||||||
print(entity.id)
|
if frame.trip_id == trip.trip_id:
|
||||||
print(entity.trip_update)
|
return frame
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
if not os.path.isdir("tmp"):
|
parser = argparse.ArgumentParser(description='Process a file with optional verbose mode')
|
||||||
download_static_data()
|
|
||||||
|
parser.add_argument('stop', help='Name of the stop')
|
||||||
|
parser.add_argument('-n', '--count', type=int, default=0, help='Departures count')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Download delay data
|
||||||
download_realtime_data()
|
download_realtime_data()
|
||||||
with Database("transit.db") as db:
|
|
||||||
|
# Download static data
|
||||||
|
if not os.path.isdir(DATA_PATH):
|
||||||
|
download_static_data()
|
||||||
|
with Database(path.join(DATA_PATH, "transit.db")) as db:
|
||||||
stop_repo = StopRepository(db)
|
stop_repo = StopRepository(db)
|
||||||
stop_time_repo = StopTimeRepository(db)
|
stop_time_repo = StopTimeRepository(db)
|
||||||
trip_repo = TripRepository(db)
|
trip_repo = TripRepository(db)
|
||||||
|
calendar = Calendar()
|
||||||
|
|
||||||
|
calendar.load_from_csv(Path(path.join(DATA_PATH, 'calendar.txt')))
|
||||||
|
|
||||||
# Create tables
|
# Create tables
|
||||||
stop_repo.create_table()
|
stop_repo.create_table()
|
||||||
@ -516,18 +591,48 @@ def main() -> None:
|
|||||||
|
|
||||||
# Load data if tables are empty
|
# Load data if tables are empty
|
||||||
if stop_repo.is_empty():
|
if stop_repo.is_empty():
|
||||||
stop_repo.load_from_csv(Path("tmp/stops.txt"))
|
stop_repo.load_from_csv(Path(path.join(DATA_PATH, "stops.txt")))
|
||||||
if stop_time_repo.is_empty():
|
if stop_time_repo.is_empty():
|
||||||
stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
|
stop_time_repo.load_from_csv(Path(path.join(DATA_PATH, "stop_times.txt")))
|
||||||
if trip_repo.is_empty():
|
if trip_repo.is_empty():
|
||||||
trip_repo.load_from_csv(Path("tmp/trips.txt"))
|
trip_repo.load_from_csv(Path(path.join(DATA_PATH, "trips.txt")))
|
||||||
|
|
||||||
stop = stop_repo.find_by_name("Polanka")[0]
|
stops = stop_repo.find_by_name(args.stop)
|
||||||
print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}")
|
for stop in stops:
|
||||||
for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5):
|
print(f"\n\nDepartures from stop: {stop.stop_name} with id: {stop.stop_id}")
|
||||||
print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}")
|
|
||||||
get_realtime(trip)
|
|
||||||
|
|
||||||
|
upcoming_departures = stop_repo.get_upcoming_departures(calendar, stop.stop_id, args.count)
|
||||||
|
|
||||||
|
for stop_time, trip in upcoming_departures:
|
||||||
|
# Check for realtime data
|
||||||
|
realtime: None | RealtimeFrame = get_realtime(trip)
|
||||||
|
|
||||||
|
departure_string: str = (
|
||||||
|
f'{stop_time.arrival_time.strftime('%H:%M:%S')}'
|
||||||
|
'(no realtime data)'
|
||||||
|
)
|
||||||
|
|
||||||
|
if realtime is not None:
|
||||||
|
# Calculate actual departure time
|
||||||
|
delay_str = (
|
||||||
|
f'+{realtime.delay}s'
|
||||||
|
if realtime.delay > 0
|
||||||
|
else f'{realtime.delay}s'
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_time.arrival_time += timedelta(seconds=realtime.delay)
|
||||||
|
last_update_time = datetime.now() - realtime.timestamp
|
||||||
|
|
||||||
|
departure_string = (
|
||||||
|
f'{stop_time.arrival_time.strftime("%H:%M:%S")} '
|
||||||
|
f'({delay_str} updated {int(last_update_time.total_seconds())}s ago)'
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nLine: {trip.route_id} "
|
||||||
|
f"Departure: {departure_string} "
|
||||||
|
f"Direction: {trip.trip_headsign}"
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user