511 lines
16 KiB
Python
Executable File
511 lines
16 KiB
Python
Executable File
import zipfile
|
|
import requests
|
|
from google.transit import gtfs_realtime_pb2
|
|
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Iterator, Tuple
|
|
import sqlite3
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
|
|
@dataclass
class Stop:
    """One stop record, as stored in the ``stops`` table / read from ``stops.txt``."""

    stop_id: int
    stop_code: str
    stop_name: str
    stop_lat: float
    stop_lon: float
    zone_id: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Stop":
        """Build a ``Stop`` from a ``sqlite3.Row`` of the ``stops`` table.

        Column values are passed through unchanged (no type conversion).
        """
        return cls(
            row["stop_id"],
            row["stop_code"],
            row["stop_name"],
            row["stop_lat"],
            row["stop_lon"],
            row["zone_id"],
        )

    @classmethod
    def from_csv_row(cls, row) -> "Stop":
        """Build a ``Stop`` from one ``csv.DictReader`` row.

        ``stop_id`` is parsed with ``int`` and the coordinates with ``float``.
        Surrounding double quotes are stripped from ``stop_code`` and
        ``stop_name``; ``zone_id`` is kept verbatim.

        Raises:
            KeyError: a required column is missing from ``row``.
            ValueError: a numeric column cannot be parsed.
        """

        def unquoted(column):
            return row[column].strip('"')

        # Evaluated in field order so the first failing column raises first.
        identifier = int(row["stop_id"])
        code = unquoted("stop_code")
        name = unquoted("stop_name")
        latitude = float(row["stop_lat"])
        longitude = float(row["stop_lon"])
        return cls(identifier, code, name, latitude, longitude, row["zone_id"])
|
|
|
|
@dataclass
class Trip:
    """One trip record, as stored in the ``trips`` table / read from ``trips.txt``.

    NOTE(review): ``brigade`` is not a core GTFS ``trips.txt`` field —
    presumably a feed-specific extension; confirm against the feed.
    """

    route_id: str
    service_id: str
    trip_id: str
    trip_headsign: str
    direction_id: int
    shape_id: int
    wheelchair_accessible: int
    brigade: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Trip":
        """Build a ``Trip`` from a ``sqlite3.Row`` that has all ``trips`` columns.

        Column values are passed through unchanged (no type conversion).
        """
        columns = (
            "route_id",
            "service_id",
            "trip_id",
            "trip_headsign",
            "direction_id",
            "shape_id",
            "wheelchair_accessible",
            "brigade",
        )
        return cls(**{column: row[column] for column in columns})

    @classmethod
    def from_csv_row(cls, row) -> "Trip":
        """Build a ``Trip`` from one ``csv.DictReader`` row.

        The four text columns have surrounding double quotes stripped; the
        four numeric columns are parsed with ``int``.

        Raises:
            KeyError: a required column is missing from ``row``.
            ValueError: a numeric column cannot be parsed.
        """
        quoted_text = ("route_id", "service_id", "trip_id", "trip_headsign")
        integers = ("direction_id", "shape_id", "wheelchair_accessible", "brigade")
        # Text columns first, then numbers: same evaluation order as field order.
        values = {column: row[column].strip('"') for column in quoted_text}
        values.update({column: int(row[column]) for column in integers})
        return cls(**values)
|
|
|
|
@dataclass
class StopTime:
    """One stop-time record, as stored in ``stop_times`` / read from ``stop_times.txt``."""

    trip_id: str
    arrival_time: str  # kept as raw text: SQLite has no TIME type
    departure_time: str
    stop_id: int
    stop_sequence: int
    stop_headsign: str
    pickup_type: int
    drop_off_type: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "StopTime":
        """Build a ``StopTime`` from a ``sqlite3.Row`` with all ``stop_times`` columns.

        Column values are passed through unchanged (no type conversion).
        """
        ordered = (
            "trip_id",
            "arrival_time",
            "departure_time",
            "stop_id",
            "stop_sequence",
            "stop_headsign",
            "pickup_type",
            "drop_off_type",
        )
        # Positional construction: ``ordered`` mirrors the dataclass field order.
        return cls(*[row[name] for name in ordered])

    @classmethod
    def from_csv_row(cls, row) -> "StopTime":
        """Build a ``StopTime`` from one ``csv.DictReader`` row.

        ``trip_id`` and ``stop_headsign`` have surrounding double quotes
        stripped, the times are kept verbatim, and the remaining columns are
        parsed with ``int``.

        Raises:
            KeyError: a required column is missing from ``row``.
            ValueError: a numeric column cannot be parsed.
        """
        trip = row["trip_id"].strip('"')
        arrives, departs = row["arrival_time"], row["departure_time"]
        stop = int(row["stop_id"])
        sequence = int(row["stop_sequence"])
        headsign = row["stop_headsign"].strip('"')
        pickup, drop_off = int(row["pickup_type"]), int(row["drop_off_type"])
        return cls(
            trip_id=trip,
            arrival_time=arrives,
            departure_time=departs,
            stop_id=stop,
            stop_sequence=sequence,
            stop_headsign=headsign,
            pickup_type=pickup,
            drop_off_type=drop_off,
        )
|
|
|
|
# 2. Create a typed database wrapper
|
|
class Database:
    """Context-managed wrapper around a single ``sqlite3`` connection.

    The connection is opened in ``__enter__`` with ``row_factory`` set to
    ``sqlite3.Row`` (columns can be read by name) and closed in ``__exit__``.
    ``Connection.close()`` does not commit, so callers must call
    :meth:`commit` to persist their changes.

    Usage::

        with Database("transit.db") as db:
            db.execute("SELECT 1")
    """

    conn: sqlite3.Connection

    def __init__(self, db_path: str):
        """Remember ``db_path``; the connection is only opened by ``__enter__``."""
        self.db_path = db_path

    def __enter__(self) -> "Database":
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Any still-open transaction is discarded here, not committed.
        self.conn.close()

    def execute(self, query: str, params=()) -> sqlite3.Cursor:
        """Run one SQL statement, binding placeholders from ``params``.

        ``params`` may be a sequence (for ``?`` placeholders) or a mapping
        (for ``:name`` placeholders). The default is an immutable empty tuple
        so no state can be shared between calls.
        """
        return self.conn.execute(query, params)

    def executemany(self, query: str, params_list=()) -> sqlite3.Cursor:
        """Run ``query`` once for each parameter set in ``params_list`` (any iterable)."""
        return self.conn.executemany(query, params_list)

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()
|
|
|
|
|
|
# 3. Create a typed repository for stops
|
|
class StopRepository:
    """Typed access to the ``stops`` table.

    :meth:`get_upcoming_departures` additionally reads the ``stop_times`` and
    ``trips`` tables, which are created by their own repositories.
    """

    db: Database

    # Shared by insert() and bulk_insert() so the two statements cannot drift.
    _INSERT_SQL = """
        INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
        VALUES (?, ?, ?, ?, ?, ?)
    """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _insert_params(stop: Stop) -> tuple:
        """Return ``stop``'s fields in the column order used by ``_INSERT_SQL``."""
        return (
            stop.stop_id,
            stop.stop_code,
            stop.stop_name,
            stop.stop_lat,
            stop.stop_lon,
            stop.zone_id,
        )

    def create_table(self) -> None:
        """Create the ``stops`` table if it does not exist, then commit."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stops (
                stop_id INTEGER PRIMARY KEY,
                stop_code TEXT NOT NULL,
                stop_name TEXT NOT NULL,
                stop_lat REAL NOT NULL,
                stop_lon REAL NOT NULL,
                zone_id TEXT NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, stop: Stop) -> None:
        """Insert a single stop and commit.

        Raises:
            sqlite3.IntegrityError: if a stop with the same ``stop_id`` exists.
        """
        self.db.execute(self._INSERT_SQL, self._insert_params(stop))
        self.db.commit()

    def bulk_insert(self, stops: list[Stop]) -> None:
        """Insert multiple stops efficiently (one ``executemany``, one commit)."""
        self.db.executemany(self._INSERT_SQL, [self._insert_params(s) for s in stops])
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stops from a CSV file with a header row.

        Rows that fail to parse (``ValueError``/``KeyError``) are reported on
        stdout and skipped; the remaining rows are inserted in one batch.
        """
        stops: list[Stop] = []

        # "utf-8-sig" drops a leading BOM so it does not end up in the first header name.
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            for row in csv.DictReader(csvfile):
                try:
                    stops.append(Stop.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if stops:
            self.bulk_insert(stops)

    def get_by_id(self, stop_id: int) -> Optional[Stop]:
        """Return the stop with ``stop_id``, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM stops WHERE stop_id = ?", (stop_id,))
        row = cursor.fetchone()
        return Stop.from_row(row) if row else None

    def get_all(self) -> Iterator[Stop]:
        """Lazily yield every stop; the query runs when iteration starts."""
        cursor = self.db.execute("SELECT * FROM stops")
        for row in cursor:
            yield Stop.from_row(row)

    def find_by_name(self, name: str) -> List[Stop]:
        """Return the stops whose name contains ``name``.

        Uses SQLite ``LIKE``, which by default is case-insensitive for ASCII
        letters only; ``%`` and ``_`` inside ``name`` act as wildcards.
        """
        cursor = self.db.execute(
            "SELECT * FROM stops WHERE stop_name LIKE ?", (f"%{name}%",)
        )
        return [Stop.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stops table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stops LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]  # EXISTS yields 0 when there are no rows, 1 otherwise

    def get_upcoming_departures(
        self,
        stop_id: int,
        limit: int = 10,
        after_time: Optional[str] = None,
    ) -> List[Tuple[StopTime, Trip]]:
        """Get departures from a stop, earliest first, with trip information.

        Args:
            stop_id: stop whose ``stop_times`` are listed.
            limit: maximum number of departures returned.
            after_time: optional ``"H:MM:SS"``/``"HH:MM:SS"`` lower bound
                (inclusive) on ``departure_time``; ``None`` means no bound.

        Returns:
            ``(StopTime, Trip)`` pairs ordered chronologically by departure.

        NOTE(review): no filtering by ``service_id`` / service calendar is done,
        so trips that do not run on a given day are still included.
        """
        # Explicit columns: ``st.*, t.*`` would return ``trip_id`` twice.
        query = """
            SELECT st.trip_id, st.arrival_time, st.departure_time, st.stop_id,
                   st.stop_sequence, st.stop_headsign, st.pickup_type, st.drop_off_type,
                   t.route_id, t.service_id, t.trip_headsign, t.direction_id,
                   t.shape_id, t.wheelchair_accessible, t.brigade
            FROM stop_times st
            JOIN trips t ON st.trip_id = t.trip_id
            WHERE st.stop_id = ?
        """
        params: list = [stop_id]

        # Times are text and GTFS also allows "H:MM:SS", so a plain string
        # comparison would put "9:05:00" after "10:00:00". Comparing length
        # first, then text, gives chronological order for both widths.
        if after_time is not None:
            query += """
              AND (length(st.departure_time) > length(?)
                   OR (length(st.departure_time) = length(?)
                       AND st.departure_time >= ?))
            """
            params += [after_time, after_time, after_time]

        query += """
            ORDER BY length(st.departure_time), st.departure_time
            LIMIT ?
        """
        params.append(limit)

        cursor = self.db.execute(query, params)
        return [(StopTime.from_row(row), Trip.from_row(row)) for row in cursor]
|
|
|
|
|
|
class StopTimeRepository:
    """Typed access to the ``stop_times`` table."""

    db: Database

    # Shared by insert() and bulk_insert() so the two statements cannot drift.
    _INSERT_SQL = """
        INSERT INTO stop_times (
            trip_id, arrival_time, departure_time, stop_id,
            stop_sequence, stop_headsign, pickup_type, drop_off_type
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _insert_params(st: StopTime) -> tuple:
        """Return ``st``'s fields in the column order used by ``_INSERT_SQL``."""
        return (
            st.trip_id,
            st.arrival_time,
            st.departure_time,
            st.stop_id,
            st.stop_sequence,
            st.stop_headsign,
            st.pickup_type,
            st.drop_off_type,
        )

    def create_table(self) -> None:
        """Create ``stop_times`` and its lookup indexes if missing, then commit.

        The lookups here filter by ``stop_id`` or by ``trip_id`` (ordered by
        ``stop_sequence``); without indexes each one scans the whole table.
        """
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stop_times (
                trip_id TEXT NOT NULL,
                arrival_time TEXT NOT NULL,
                departure_time TEXT NOT NULL,
                stop_id INTEGER NOT NULL,
                stop_sequence INTEGER NOT NULL,
                stop_headsign TEXT NOT NULL,
                pickup_type INTEGER NOT NULL,
                drop_off_type INTEGER NOT NULL,
                FOREIGN KEY (stop_id) REFERENCES stops (stop_id)
            )
        """)
        self.db.execute(
            "CREATE INDEX IF NOT EXISTS idx_stop_times_stop_id "
            "ON stop_times (stop_id)"
        )
        self.db.execute(
            "CREATE INDEX IF NOT EXISTS idx_stop_times_trip_id "
            "ON stop_times (trip_id, stop_sequence)"
        )
        self.db.commit()

    def insert(self, stop_time: StopTime) -> None:
        """Insert a single stop time and commit."""
        self.db.execute(self._INSERT_SQL, self._insert_params(stop_time))
        self.db.commit()

    def bulk_insert(self, stop_times: list[StopTime]) -> None:
        """Insert many stop times with one ``executemany`` and a single commit."""
        self.db.executemany(
            self._INSERT_SQL, [self._insert_params(st) for st in stop_times]
        )
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stop times from a CSV file with a header row.

        Rows that fail to parse (``ValueError``/``KeyError``) are reported on
        stdout and skipped; the remaining rows are inserted in one batch.
        """
        stop_times: list[StopTime] = []

        # "utf-8-sig" drops a leading BOM so it does not end up in the first header name.
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            for row in csv.DictReader(csvfile):
                try:
                    stop_times.append(StopTime.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if stop_times:
            self.bulk_insert(stop_times)

    def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
        """Return the stop times of one trip in ``stop_sequence`` order."""
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
            (trip_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
        """Return the stop times at one stop in chronological arrival order.

        Times are text and GTFS also allows ``"H:MM:SS"``; ordering by length
        first keeps ``"9:05:00"`` ahead of ``"10:00:00"``.
        """
        cursor = self.db.execute(
            "SELECT * FROM stop_times WHERE stop_id = ? "
            "ORDER BY length(arrival_time), arrival_time",
            (stop_id,),
        )
        return [StopTime.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Return ``True`` when ``stop_times`` has no rows."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stop_times LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]
|
|
|
|
class TripRepository:
    """Typed access to the ``trips`` table (primary key ``trip_id``)."""

    db: Database

    # One statement reused by both the single-row and the batch insert.
    _INSERT_SQL = """
        INSERT INTO trips (
            route_id, service_id, trip_id, trip_headsign,
            direction_id, shape_id, wheelchair_accessible, brigade
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _as_tuple(trip: Trip) -> tuple:
        """Flatten ``trip`` into the positional parameters of ``_INSERT_SQL``."""
        return (
            trip.route_id,
            trip.service_id,
            trip.trip_id,
            trip.trip_headsign,
            trip.direction_id,
            trip.shape_id,
            trip.wheelchair_accessible,
            trip.brigade,
        )

    def create_table(self) -> None:
        """Create ``trips`` unless it already exists, then commit."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS trips (
                route_id TEXT NOT NULL,
                service_id TEXT NOT NULL,
                trip_id TEXT PRIMARY KEY,
                trip_headsign TEXT NOT NULL,
                direction_id INTEGER NOT NULL,
                shape_id INTEGER NOT NULL,
                wheelchair_accessible INTEGER NOT NULL,
                brigade INTEGER NOT NULL
            )
        """)
        self.db.commit()

    def insert(self, trip: Trip) -> None:
        """Store one trip and commit immediately."""
        self.db.execute(self._INSERT_SQL, self._as_tuple(trip))
        self.db.commit()

    def bulk_insert(self, trips: list[Trip]) -> None:
        """Store every trip via a single ``executemany`` followed by one commit."""
        self.db.executemany(self._INSERT_SQL, map(self._as_tuple, trips))
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Parse a trips CSV (header row required) and batch-insert the valid rows.

        Unparseable rows (``ValueError``/``KeyError``) are printed and skipped.
        """
        parsed: list[Trip] = []
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as handle:
            for record in csv.DictReader(handle):
                try:
                    parsed.append(Trip.from_csv_row(record))
                except (ValueError, KeyError) as err:
                    print(f"Error processing row: {record}. Error: {err}")
        if parsed:
            self.bulk_insert(parsed)

    def get_by_id(self, trip_id: str) -> Optional[Trip]:
        """Look up a trip by primary key; ``None`` when absent."""
        found = self.db.execute(
            "SELECT * FROM trips WHERE trip_id = ?", (trip_id,)
        ).fetchone()
        if found is None:
            return None
        return Trip.from_row(found)

    def get_by_route_id(self, route_id: str) -> List[Trip]:
        """All trips belonging to ``route_id`` (no particular order)."""
        matches = self.db.execute("SELECT * FROM trips WHERE route_id = ?", (route_id,))
        return list(map(Trip.from_row, matches))

    def get_by_headsign(self, headsign: str) -> List[Trip]:
        """Trips whose headsign contains ``headsign`` (SQLite ``LIKE`` substring match)."""
        pattern = f"%{headsign}%"
        matches = self.db.execute(
            "SELECT * FROM trips WHERE trip_headsign LIKE ?",
            (pattern,),
        )
        return list(map(Trip.from_row, matches))

    def is_empty(self) -> bool:
        """``True`` when the ``trips`` table holds no rows."""
        (has_rows,) = self.db.execute(
            "SELECT EXISTS(SELECT 1 FROM trips LIMIT 1)"
        ).fetchone()
        return has_rows == 0
|
|
|
|
def download_data(
    url: str = "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile",
    zip_path: str = "data.zip",
    extract_dir: str = "tmp",
    timeout: float = 60.0,
) -> None:
    """Download the GTFS zip archive from ``url`` and extract it into ``extract_dir``.

    Args:
        url: address of the GTFS archive (defaults to the ZTM Poznań feed).
        zip_path: where the downloaded archive is written.
        extract_dir: directory the archive members are extracted into.
        timeout: seconds to wait on the HTTP request before giving up.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: the server did not respond within ``timeout``.
        zipfile.BadZipFile: the downloaded payload is not a zip archive.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of saving an HTML error page as the archive.
    response.raise_for_status()

    # Close (and flush) the file before ZipFile reopens it for reading.
    with open(zip_path, "wb") as archive:
        archive.write(response.content)

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
|
|
|
|
def get_realtime(trip: Trip, feed_path: str = "trip_updates.pb") -> None:
    """Print the GTFS-Realtime trip updates that refer to ``trip``.

    Reads a serialized ``FeedMessage`` from ``feed_path`` and prints the
    entity id and trip update of every matching entity.

    Args:
        trip: static-GTFS trip to look for.
        feed_path: path of the protobuf feed file to parse.

    NOTE(review): nothing in this file downloads ``feed_path``; it must
    already exist on disk.
    """
    with open(feed_path, "rb") as f:
        payload = f.read()

    feed = gtfs_realtime_pb2.FeedMessage()  # type: ignore
    feed.ParseFromString(payload)

    wanted_id = str(trip.trip_id)  # loop-invariant
    for entity in feed.entity:
        # A FeedEntity may carry a vehicle position or alert instead of a trip update.
        if not entity.HasField("trip_update"):
            continue
        # NOTE(review): the realtime trip_id is compared with its last two
        # characters dropped — presumably a feed-specific suffix; confirm.
        if entity.trip_update.trip.trip_id[:-2] == wanted_id:
            print(entity.id)
            print(entity.trip_update)
|
|
|
|
def main() -> None:
    """Demo entry point for the transit database.

    Populates ``transit.db`` from the CSVs in ``tmp/`` on first run (the
    directory ``download_data()`` extracts into; this function does not
    download anything). It then prints the first departure from the first stop
    whose name contains "Polanka", and calls ``get_realtime`` for that trip.
    """
    with Database("transit.db") as db:
        stop_repo = StopRepository(db)
        stop_time_repo = StopTimeRepository(db)
        trip_repo = TripRepository(db)

        # CREATE TABLE IF NOT EXISTS: safe to run on every start.
        stop_repo.create_table()
        stop_time_repo.create_table()
        trip_repo.create_table()

        # Import each table only when it is empty, so later runs reuse transit.db.
        if stop_repo.is_empty():
            stop_repo.load_from_csv(Path("tmp/stops.txt"))
        if stop_time_repo.is_empty():
            stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
        if trip_repo.is_empty():
            trip_repo.load_from_csv(Path("tmp/trips.txt"))

        matches = stop_repo.find_by_name("Polanka")
        if not matches:
            # Report a missing stop instead of failing with IndexError.
            print('No stop matching "Polanka" found.')
            return
        stop = matches[0]

        for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 1):
            # The label says "Departure", so show departure_time, not arrival_time.
            print(f"Line: {trip.route_id} Departure: {stop_time.departure_time}")
            get_realtime(trip)


if __name__ == "__main__":
    main()
|