"""Load ZTM (ztm.poznan.pl) GTFS schedule data into SQLite and look up realtime updates.

ZTM API docs: https://www.ztm.poznan.pl/otwarte-dane/dla-deweloperow/

Static data (routes), delivered as a ZIP archive::

    curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile' -o routes.zip

Realtime data (delays), delivered as a protobuf file::

    curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb' -o trip_updates.pb

Generated artifacts are expected to be git-ignored: ``tmp``, ``proto``,
``*.txt``, ``*.pb``, ``*.db``, ``**/__pycache__`` and ``uv.lock``.
"""

import csv
import sqlite3
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

# Endpoint serving the static GTFS archive (same URL as in the README).
GTFS_STATIC_URL = "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile"

# Anything sqlite3 accepts as bound parameters: positional or named.
SqlParams = Union[Sequence[object], Mapping[str, object]]


@dataclass
class Stop:
    """One row of the ``stops`` table / one record of ``stops.txt``."""

    stop_id: int
    stop_code: str
    stop_name: str
    stop_lat: float
    stop_lon: float
    zone_id: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Stop":
        """Build a ``Stop`` from a database row addressed by column name."""
        return cls(
            stop_id=row["stop_id"],
            stop_code=row["stop_code"],
            stop_name=row["stop_name"],
            stop_lat=row["stop_lat"],
            stop_lon=row["stop_lon"],
            zone_id=row["zone_id"],
        )

    @classmethod
    def from_csv_row(cls, row: Mapping[str, str]) -> "Stop":
        """Build a ``Stop`` from a CSV record, converting numeric fields.

        Raises:
            KeyError: a required column is missing.
            ValueError: a numeric column cannot be parsed.
        """
        return cls(
            stop_id=int(row["stop_id"]),
            # Strip stray double quotes that survive CSV parsing.
            stop_code=row["stop_code"].strip('"'),
            stop_name=row["stop_name"].strip('"'),
            stop_lat=float(row["stop_lat"]),
            stop_lon=float(row["stop_lon"]),
            zone_id=row["zone_id"],
        )


@dataclass
class Trip:
    """One row of the ``trips`` table / one record of ``trips.txt``."""

    route_id: str
    service_id: str
    trip_id: str
    trip_headsign: str
    direction_id: int
    shape_id: int
    wheelchair_accessible: int
    brigade: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Trip":
        """Build a ``Trip`` from a database row addressed by column name."""
        return cls(
            route_id=row["route_id"],
            service_id=row["service_id"],
            trip_id=row["trip_id"],
            trip_headsign=row["trip_headsign"],
            direction_id=row["direction_id"],
            shape_id=row["shape_id"],
            wheelchair_accessible=row["wheelchair_accessible"],
            brigade=row["brigade"],
        )

    @classmethod
    def from_csv_row(cls, row: Mapping[str, str]) -> "Trip":
        """Build a ``Trip`` from a CSV record, converting numeric fields.

        Raises:
            KeyError: a required column is missing.
            ValueError: a numeric column cannot be parsed.
        """
        return cls(
            route_id=row["route_id"].strip('"'),
            service_id=row["service_id"].strip('"'),
            trip_id=row["trip_id"].strip('"'),
            trip_headsign=row["trip_headsign"].strip('"'),
            direction_id=int(row["direction_id"]),
            shape_id=int(row["shape_id"]),
            wheelchair_accessible=int(row["wheelchair_accessible"]),
            brigade=int(row["brigade"]),
        )


@dataclass
class StopTime:
    """One row of the ``stop_times`` table / one record of ``stop_times.txt``."""

    trip_id: str
    arrival_time: str  # Using str for now since SQLite doesn't have a TIME type
    departure_time: str
    stop_id: int
    stop_sequence: int
    stop_headsign: str
    pickup_type: int
    drop_off_type: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "StopTime":
        """Build a ``StopTime`` from a database row addressed by column name."""
        return cls(
            trip_id=row["trip_id"],
            arrival_time=row["arrival_time"],
            departure_time=row["departure_time"],
            stop_id=row["stop_id"],
            stop_sequence=row["stop_sequence"],
            stop_headsign=row["stop_headsign"],
            pickup_type=row["pickup_type"],
            drop_off_type=row["drop_off_type"],
        )

    @classmethod
    def from_csv_row(cls, row: Mapping[str, str]) -> "StopTime":
        """Build a ``StopTime`` from a CSV record, converting numeric fields.

        Raises:
            KeyError: a required column is missing.
            ValueError: a numeric column cannot be parsed.
        """
        return cls(
            trip_id=row["trip_id"].strip('"'),
            arrival_time=row["arrival_time"],
            departure_time=row["departure_time"],
            stop_id=int(row["stop_id"]),
            stop_sequence=int(row["stop_sequence"]),
            stop_headsign=row["stop_headsign"].strip('"'),
            pickup_type=int(row["pickup_type"]),
            drop_off_type=int(row["drop_off_type"]),
        )


# 2. Create a typed database wrapper
class Database:
    """Context manager around a SQLite connection that yields ``sqlite3.Row`` rows.

    The connection is opened in ``__enter__`` and closed in ``__exit__``;
    ``execute``/``executemany``/``commit`` may only be used in between.
    """

    conn: sqlite3.Connection

    def __init__(self, db_path: str):
        self.db_path = db_path

    def __enter__(self) -> "Database":
        self.conn = sqlite3.connect(self.db_path)
        # Row objects allow column access by name, which the from_row()
        # constructors rely on.
        self.conn.row_factory = sqlite3.Row
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.conn.close()

    def execute(self, query: str, params: SqlParams = ()) -> sqlite3.Cursor:
        """Run a single statement with optional bound parameters."""
        # Immutable default: a shared ``{}`` default could be mutated by accident.
        return self.conn.execute(query, params)

    def executemany(
        self, query: str, params_list: Iterable[SqlParams] = ()
    ) -> sqlite3.Cursor:
        """Run one statement once per parameter set in ``params_list``."""
        return self.conn.executemany(query, params_list)

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()


# 3. Create a typed repository for stops
class StopRepository:
    """Persistence and queries for ``Stop`` records in the ``stops`` table."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``stops`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stops (
                stop_id INTEGER PRIMARY KEY,
                stop_code TEXT NOT NULL,
                stop_name TEXT NOT NULL,
                stop_lat REAL NOT NULL,
                stop_lon REAL NOT NULL,
                zone_id TEXT NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, stop: Stop) -> None:
        """Insert one stop and commit."""
        self.db.execute(
            """
            INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                stop.stop_id,
                stop.stop_code,
                stop.stop_name,
                stop.stop_lat,
                stop.stop_lon,
                stop.zone_id,
            ),
        )
        self.db.commit()

    def bulk_insert(self, stops: List[Stop]) -> None:
        """Insert multiple stops efficiently."""
        self.db.executemany(
            """
            INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            [
                (s.stop_id, s.stop_code, s.stop_name, s.stop_lat, s.stop_lon, s.zone_id)
                for s in stops
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stops from a CSV file.

        Rows that fail to parse are reported on stdout and skipped; the
        remaining rows are inserted in one batch.
        """
        stops: List[Stop] = []

        # utf-8-sig drops a leading byte-order mark so the first header is clean.
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                try:
                    stops.append(Stop.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if stops:
            self.bulk_insert(stops)

    def get_by_id(self, stop_id: int) -> Optional[Stop]:
        """Return the stop with ``stop_id``, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM stops WHERE stop_id = ?", (stop_id,))
        row = cursor.fetchone()
        return Stop.from_row(row) if row else None

    def get_all(self) -> Iterator[Stop]:
        """Lazily yield every stop in the table."""
        cursor = self.db.execute("SELECT * FROM stops")
        for row in cursor:
            yield Stop.from_row(row)

    def find_by_name(self, name: str) -> List[Stop]:
        """Return stops whose name contains ``name`` (SQL ``LIKE`` match)."""
        cursor = self.db.execute(
            "SELECT * FROM stops WHERE stop_name LIKE ?", (f"%{name}%",)
        )
        return [Stop.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stops table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stops LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]  # SQLite returns 0 if empty, 1 if not empty

    def get_upcoming_departures(
        self, stop_id: int, limit: int = 10, after: Optional[str] = None
    ) -> List[Tuple[StopTime, Trip]]:
        """Get departures from a stop with trip information, earliest first.

        Args:
            stop_id: Stop to list departures for.
            limit: Maximum number of departures returned.
            after: Optional lower bound for ``departure_time``. It is compared
                as text, so it must use the same format as the stored values.
                NOTE(review): presumably zero-padded ``HH:MM:SS``; confirm
                against the feed. ``None`` (default) applies no time filter.

        Returns:
            ``(StopTime, Trip)`` pairs ordered by ``departure_time``.
        """
        query = """
            SELECT st.*, t.*
            FROM stop_times st
            JOIN trips t ON st.trip_id = t.trip_id
            WHERE st.stop_id = ?
        """
        params: List[object] = [stop_id]
        if after is not None:
            query += " AND st.departure_time >= ?"
            params.append(after)
        query += " ORDER BY st.departure_time LIMIT ?"
        params.append(limit)

        cursor = self.db.execute(query, params)
        # Both tables expose ``trip_id``; the join guarantees the two values
        # are equal, so either dataclass may read it by name.
        return [(StopTime.from_row(row), Trip.from_row(row)) for row in cursor]


class StopTimeRepository:
    """Persistence and queries for ``StopTime`` records in ``stop_times``."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``stop_times`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stop_times (
                trip_id TEXT NOT NULL,
                arrival_time TEXT NOT NULL,
                departure_time TEXT NOT NULL,
                stop_id INTEGER NOT NULL,
                stop_sequence INTEGER NOT NULL,
                stop_headsign TEXT NOT NULL,
                pickup_type INTEGER NOT NULL,
                drop_off_type INTEGER NOT NULL,
                FOREIGN KEY (stop_id) REFERENCES stops (stop_id)
            )
        """)
        self.db.commit()

    def insert(self, stop_time: StopTime) -> None:
        """Insert one stop time and commit."""
        self.db.execute(
            """
            INSERT INTO stop_times (
                trip_id, arrival_time, departure_time, stop_id,
                stop_sequence, stop_headsign, pickup_type, drop_off_type
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                stop_time.trip_id,
                stop_time.arrival_time,
                stop_time.departure_time,
                stop_time.stop_id,
                stop_time.stop_sequence,
                stop_time.stop_headsign,
                stop_time.pickup_type,
                stop_time.drop_off_type,
            ),
        )
        self.db.commit()

    def bulk_insert(self, stop_times: List[StopTime]) -> None:
        """Insert multiple stop times in one batch and commit."""
        self.db.executemany(
            """
            INSERT INTO stop_times (
                trip_id, arrival_time, departure_time, stop_id,
                stop_sequence, stop_headsign, pickup_type, drop_off_type
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                (
                    st.trip_id,
                    st.arrival_time,
                    st.departure_time,
                    st.stop_id,
                    st.stop_sequence,
                    st.stop_headsign,
                    st.pickup_type,
                    st.drop_off_type,
                )
                for st in stop_times
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stop times from a CSV file, skipping rows that fail to parse."""
        stop_times: List[StopTime] = []

        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                try:
                    stop_times.append(StopTime.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if stop_times:
            self.bulk_insert(stop_times)

    def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
        """Return the stop times of a trip ordered by ``stop_sequence``."""
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
            (trip_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
        """Return the stop times at a stop ordered by ``arrival_time``."""
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
            (stop_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stop_times table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stop_times LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]


class TripRepository:
    """Persistence and queries for ``Trip`` records in the ``trips`` table."""

    db: Database

    def __init__(self, db: Database):
        self.db = db

    def create_table(self) -> None:
        """Create the ``trips`` table if it does not exist yet."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS trips (
                route_id TEXT NOT NULL,
                service_id TEXT NOT NULL,
                trip_id TEXT PRIMARY KEY,
                trip_headsign TEXT NOT NULL,
                direction_id INTEGER NOT NULL,
                shape_id INTEGER NOT NULL,
                wheelchair_accessible INTEGER NOT NULL,
                brigade INTEGER NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, trip: Trip) -> None:
        """Insert one trip and commit."""
        self.db.execute(
            """
            INSERT INTO trips (
                route_id, service_id, trip_id, trip_headsign,
                direction_id, shape_id, wheelchair_accessible, brigade
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                trip.route_id,
                trip.service_id,
                trip.trip_id,
                trip.trip_headsign,
                trip.direction_id,
                trip.shape_id,
                trip.wheelchair_accessible,
                trip.brigade,
            ),
        )
        self.db.commit()

    def bulk_insert(self, trips: List[Trip]) -> None:
        """Insert multiple trips in one batch and commit."""
        self.db.executemany(
            """
            INSERT INTO trips (
                route_id, service_id, trip_id, trip_headsign,
                direction_id, shape_id, wheelchair_accessible, brigade
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                (
                    t.route_id,
                    t.service_id,
                    t.trip_id,
                    t.trip_headsign,
                    t.direction_id,
                    t.shape_id,
                    t.wheelchair_accessible,
                    t.brigade,
                )
                for t in trips
            ],
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load trips from a CSV file, skipping rows that fail to parse."""
        trips: List[Trip] = []

        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                try:
                    trips.append(Trip.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if trips:
            self.bulk_insert(trips)

    def get_by_id(self, trip_id: str) -> Optional[Trip]:
        """Return the trip with ``trip_id``, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM trips WHERE trip_id = ?", (trip_id,))
        row = cursor.fetchone()
        return Trip.from_row(row) if row else None

    def get_by_route_id(self, route_id: str) -> List[Trip]:
        """Return all trips of a route."""
        cursor = self.db.execute("SELECT * FROM trips WHERE route_id = ?", (route_id,))
        return [Trip.from_row(row) for row in cursor]

    def get_by_headsign(self, headsign: str) -> List[Trip]:
        """Return trips whose headsign contains ``headsign`` (SQL ``LIKE``)."""
        cursor = self.db.execute(
            "SELECT * FROM trips WHERE trip_headsign LIKE ?",
            (f"%{headsign}%",),
        )
        return [Trip.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the trips table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM trips LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]


def download_data() -> None:
    """Download the static GTFS archive to ``data.zip`` and extract it into ``tmp``.

    Raises:
        requests.HTTPError: the server answered with an error status.
        zipfile.BadZipFile: the downloaded payload is not a ZIP archive.
    """
    # Imported here so the SQLite/CSV layer stays importable without the
    # network dependency installed.
    import requests

    response = requests.get(GTFS_STATIC_URL, timeout=60)
    # Fail loudly rather than writing an HTML error page to data.zip.
    response.raise_for_status()
    Path("data.zip").write_bytes(response.content)
    with zipfile.ZipFile("data.zip", "r") as zip_ref:
        zip_ref.extractall("tmp")


def get_realtime(trip: Trip, feed_path: str = "trip_updates.pb") -> None:
    """Print the realtime feed entities that belong to ``trip``.

    Args:
        trip: Scheduled trip to look up.
        feed_path: GTFS-Realtime ``FeedMessage`` file to read.

    Raises:
        FileNotFoundError: ``feed_path`` does not exist.
    """
    # Imported here so the SQLite/CSV layer stays importable without the
    # protobuf bindings installed.
    from google.transit import gtfs_realtime_pb2

    with open(feed_path, "rb") as f:
        payload = f.read()
    feed = gtfs_realtime_pb2.FeedMessage()  # type: ignore
    feed.ParseFromString(payload)
    for entity in feed.entity:
        # NOTE(review): the realtime trip_id is compared with its last two
        # characters dropped. Presumably the feed appends a two-character
        # suffix to the schedule trip_id; confirm against live data.
        if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id):
            print(entity.id)
            print(entity.trip_update)


def main() -> None:
    """Populate ``transit.db`` if needed and print the first departure from "Polanka"."""
    with Database("transit.db") as db:
        stop_repo = StopRepository(db)
        stop_time_repo = StopTimeRepository(db)
        trip_repo = TripRepository(db)

        # Create tables
        stop_repo.create_table()
        stop_time_repo.create_table()
        trip_repo.create_table()

        # Load data if tables are empty
        sources = [
            (stop_repo, Path("tmp/stops.txt")),
            (stop_time_repo, Path("tmp/stop_times.txt")),
            (trip_repo, Path("tmp/trips.txt")),
        ]
        pending = [(repo, path) for repo, path in sources if repo.is_empty()]
        # On a fresh checkout the extracted files do not exist yet.
        if any(not path.exists() for _, path in pending):
            download_data()
        for repo, path in pending:
            repo.load_from_csv(path)

        matches = stop_repo.find_by_name("Polanka")
        if not matches:
            print("No stop matching 'Polanka' found.")
            return
        stop = matches[0]
        for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 1):
            print(f"Line: {trip.route_id} Departure: {stop_time.departure_time}")
            get_realtime(trip)


if __name__ == "__main__":
    main()