diff --git a/src/realtime.py b/src/realtime.py index 7590ca9..99da173 100755 --- a/src/realtime.py +++ b/src/realtime.py @@ -1,14 +1,17 @@ -import zipfile -import requests -import os -from google.transit import gtfs_realtime_pb2 - from dataclasses import dataclass -from typing import List, Optional, Iterator, Tuple -import sqlite3 -import csv +from datetime import datetime, timedelta +from google.transit import gtfs_realtime_pb2 from pathlib import Path +from typing import List, Optional, Iterator, Tuple, Dict +import argparse +import csv +import os +from os import path +import requests +import sqlite3 +import zipfile +DATA_PATH: Path = Path('tmp') @dataclass class Stop: @@ -79,12 +82,19 @@ class Trip: brigade=int(row["brigade"]), ) +def normalize_time(time_str: str) -> str: + """Convert 24:00:00 format to 00:00:00 of the next day.""" + if time_str.startswith('24:'): + return '00' + time_str[2:] + if time_str.startswith('25:'): + return '01' + time_str[2:] + return time_str @dataclass class StopTime: trip_id: str - arrival_time: str # Using str for now since SQLite doesn't have a TIME type - departure_time: str + arrival_time: datetime + departure_time: datetime stop_id: int stop_sequence: int stop_headsign: str @@ -95,8 +105,8 @@ class StopTime: def from_row(cls, row: sqlite3.Row) -> "StopTime": return cls( trip_id=row["trip_id"], - arrival_time=row["arrival_time"], - departure_time=row["departure_time"], + arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"), + departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"), stop_id=row["stop_id"], stop_sequence=row["stop_sequence"], stop_headsign=row["stop_headsign"], @@ -108,8 +118,8 @@ class StopTime: def from_csv_row(cls, row) -> "StopTime": return cls( trip_id=row["trip_id"].strip('"'), - arrival_time=row["arrival_time"], - departure_time=row["departure_time"], + arrival_time=datetime.strptime(normalize_time(row["arrival_time"]), "%H:%M:%S"), + departure_time=datetime.strptime(normalize_time(row["departure_time"]), "%H:%M:%S"), stop_id=int(row["stop_id"]), stop_sequence=int(row["stop_sequence"]), stop_headsign=row["stop_headsign"].strip('"'), @@ -117,8 +127,57 @@ class StopTime: drop_off_type=int(row["drop_off_type"]), ) +@dataclass +class RealtimeFrame: + trip_id: str + route_id: str + stop_sequence: int + delay: int + vehicle_id: str + vehicle_label: str + timestamp: datetime + + @staticmethod + def from_feed_entry(entry) -> 'RealtimeFrame': + assert len(entry.stop_time_update) == 1 + return RealtimeFrame( + trip_id=entry.trip.trip_id, + route_id=entry.trip.route_id, + stop_sequence=int(entry.stop_time_update[0].stop_sequence), + delay=int(entry.stop_time_update[0].arrival.delay), + vehicle_id=entry.vehicle.id, + vehicle_label=entry.vehicle.label, + timestamp=datetime.fromtimestamp(int(entry.timestamp)) + ) + + +class Calendar: + start_date: datetime + end_date: datetime + weekdays: Dict[int, int] = {} + keys: Dict[str, int] = { + 'monday': 0, + 'tuesday': 1, + 'wednesday': 2, + 'thursday': 3, + 'friday': 4, + 'saturday': 5, + 'sunday': 6, + } + + def load_from_csv(self, csv_path: Path): + with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + for day, id in self.keys.items(): + if row[day] == '1': + self.weekdays[id] = int(row['service_id']) + if len(self.weekdays.items()) != 7: + raise Exception('Could not match every weekday') + + def get_todays_service_id(self) -> int: + return self.weekdays[datetime.now().weekday()] -# 2. Create a typed database wrapper class Database: conn: sqlite3.Connection @@ -142,8 +201,6 @@ class Database: def commit(self) -> None: self.conn.commit() - -# 3. Create a typed repository for stops class StopRepository: db: Database @@ -233,19 +290,21 @@ class StopRepository: return not result[0] # SQLite returns 0 if empty, 1 if not empty def get_upcoming_departures( - self, stop_id: int, limit: int = 10 + self, calendar: Calendar, stop_id: int, limit: int = 10 ) -> List[Tuple[StopTime, Trip]]: """Get upcoming departures from a stop with trip information.""" + current_time = datetime.now().strftime("%H:%M:%S") + service_id = calendar.get_todays_service_id() cursor = self.db.execute( """ SELECT st.*, t.* FROM stop_times st JOIN trips t ON st.trip_id = t.trip_id - WHERE st.stop_id = ? + WHERE st.stop_id = ? AND st.departure_time >= ? AND t.service_id = ? ORDER BY st.departure_time LIMIT ? """, - (stop_id, limit), + (stop_id, current_time, service_id, limit), ) departures = [] @@ -475,39 +534,55 @@ class TripRepository: def download_static_data(): + zip_path = '/tmp/data.zip' response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile") - open("/tmp/data.zip", "wb").write(response.content) - with zipfile.ZipFile("/tmp/data.zip", "r") as zip_ref: + open(zip_path, "wb").write(response.content) + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall("tmp") - os.remove('/tmp/data.zip') + os.remove(zip_path) def download_realtime_data(): response = requests.get( + # "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=feeds.pb" "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb" ) - open("tmp/trip_updates.pb", "wb").write(response.content) + open(path.join(DATA_PATH, 'trip_updates.pb'), "wb").write(response.content) -def get_realtime(trip: Trip): - with open("tmp/trip_updates.pb", "rb") as f: +def get_realtime(trip: Trip) -> None | RealtimeFrame: + with open(path.join(DATA_PATH, 'trip_updates.pb'), "rb") as f: response = f.read() feed = gtfs_realtime_pb2.FeedMessage() # type: ignore feed.ParseFromString(response) + print(feed, file=open('feed.txt', "w")) for entity in feed.entity: - if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id): - print(entity.id) - print(entity.trip_update) + frame = RealtimeFrame.from_feed_entry(entity.trip_update) + if frame.trip_id == trip.trip_id: + return frame def main() -> None: - if not os.path.isdir("tmp"): + parser = argparse.ArgumentParser(description='Process a file with optional verbose mode') + + parser.add_argument('stop', help='Name of the stop') + parser.add_argument('-n', '--count', type=int, default=0, help='Departures count') + + args = parser.parse_args() + + # Download delay data + download_realtime_data() + + # Download static data + if not os.path.isdir(DATA_PATH): download_static_data() - download_realtime_data() - with Database("transit.db") as db: + with Database(path.join(DATA_PATH, "transit.db")) as db: stop_repo = StopRepository(db) stop_time_repo = StopTimeRepository(db) trip_repo = TripRepository(db) + calendar = Calendar() + + calendar.load_from_csv(Path(path.join(DATA_PATH, 'calendar.txt'))) # Create tables stop_repo.create_table() @@ -516,18 +591,48 @@ def main() -> None: # Load data if tables are empty if stop_repo.is_empty(): - stop_repo.load_from_csv(Path("tmp/stops.txt")) + stop_repo.load_from_csv(Path(path.join(DATA_PATH, "stops.txt"))) if stop_time_repo.is_empty(): - stop_time_repo.load_from_csv(Path("tmp/stop_times.txt")) + stop_time_repo.load_from_csv(Path(path.join(DATA_PATH, "stop_times.txt"))) if trip_repo.is_empty(): - trip_repo.load_from_csv(Path("tmp/trips.txt")) + trip_repo.load_from_csv(Path(path.join(DATA_PATH, "trips.txt"))) - stop = stop_repo.find_by_name("Polanka")[0] - print(f"Querying stop: {stop.stop_name} with id: {stop.stop_id}") - for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 5): - print(f"Line: {trip.route_id} Departure: {stop_time.arrival_time}") - get_realtime(trip) + stops = stop_repo.find_by_name(args.stop) + for stop in stops: + print(f"\n\nDepartures from stop: {stop.stop_name} with id: {stop.stop_id}") + upcoming_departures = stop_repo.get_upcoming_departures(calendar, stop.stop_id, args.count) + + for stop_time, trip in upcoming_departures: + # Check for realtime data + realtime: None | RealtimeFrame = get_realtime(trip) + + departure_string: str = ( + f'{stop_time.arrival_time.strftime('%H:%M:%S')}' + '(no realtime data)' + ) + + if realtime is not None: + # Calculate actual departure time + delay_str = ( + f'+{realtime.delay}s' + if realtime.delay > 0 + else f'{realtime.delay}s' + ) + + stop_time.arrival_time += timedelta(seconds=realtime.delay) + last_update_time = datetime.now() - realtime.timestamp + + departure_string = ( + f'{stop_time.arrival_time.strftime("%H:%M:%S")} ' + f'({delay_str} updated {int(last_update_time.total_seconds())}s ago)' + ) + + print( + f"\nLine: {trip.route_id} " + f"Departure: {departure_string} " + f"Direction: {trip.trip_headsign}" + ) if __name__ == "__main__": main()