wip data model

This commit is contained in:
Dawid Pietrykowski 2025-02-14 01:02:32 +01:00
parent f1b75e7dae
commit 947fa37aac
3 changed files with 515 additions and 21 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
tmp
proto
*.txt
*.pb
*.db
**/__pycache__
uv.lock

View File

@ -3,6 +3,6 @@
##### ZTM API docs
https://www.ztm.poznan.pl/otwarte-dane/dla-deweloperow/
##### ZTM API - static data (routes)
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile' -o routes.pb
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile' -o routes.zip
##### ZTM API - realtime data (delays)
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb' -o trip_updates.pb

View File

@ -1,23 +1,510 @@
from google.transit import gtfs_realtime_pb2
import zipfile
import requests
from google.transit import gtfs_realtime_pb2
feed = gtfs_realtime_pb2.FeedMessage()
# response = requests.get('https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb')
# response = ''
# with open('trip_updates.pb', 'rb') as f:
# response = f.read()
# feed.ParseFromString(response)
# for entity in feed.entity:
# print(entity)
# if entity.HasField('trip_update'):
# print(entity.trip_update)
from dataclasses import dataclass
from typing import List, Optional, Iterator, Tuple
import sqlite3
import csv
from pathlib import Path
with open('feeds.pb', 'rb') as f:
response = f.read()
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(response)
for entity in feed.entity:
print(entity)
# if entity.HasField('trip_update'):
# print(entity.trip_update)
@dataclass
class Stop:
stop_id: int
stop_code: str
stop_name: str
stop_lat: float
stop_lon: float
zone_id: str
@classmethod
def from_row(cls, row: sqlite3.Row) -> "Stop":
return cls(
stop_id=row["stop_id"],
stop_code=row["stop_code"],
stop_name=row["stop_name"],
stop_lat=row["stop_lat"],
stop_lon=row["stop_lon"],
zone_id=row["zone_id"],
)
@classmethod
def from_csv_row(cls, row) -> "Stop":
return cls(
stop_id=int(row["stop_id"]),
stop_code=row["stop_code"].strip('"'),
stop_name=row["stop_name"].strip('"'),
stop_lat=float(row["stop_lat"]),
stop_lon=float(row["stop_lon"]),
zone_id=row["zone_id"],
)
@dataclass
class Trip:
route_id: str
service_id: str
trip_id: str
trip_headsign: str
direction_id: int
shape_id: int
wheelchair_accessible: int
brigade: int
@classmethod
def from_row(cls, row: sqlite3.Row) -> "Trip":
return cls(
route_id=row["route_id"],
service_id=row["service_id"],
trip_id=row["trip_id"],
trip_headsign=row["trip_headsign"],
direction_id=row["direction_id"],
shape_id=row["shape_id"],
wheelchair_accessible=row["wheelchair_accessible"],
brigade=row["brigade"],
)
@classmethod
def from_csv_row(cls, row) -> "Trip":
return cls(
route_id=row["route_id"].strip('"'),
service_id=row["service_id"].strip('"'),
trip_id=row["trip_id"].strip('"'),
trip_headsign=row["trip_headsign"].strip('"'),
direction_id=int(row["direction_id"]),
shape_id=int(row["shape_id"]),
wheelchair_accessible=int(row["wheelchair_accessible"]),
brigade=int(row["brigade"]),
)
@dataclass
class StopTime:
trip_id: str
arrival_time: str # Using str for now since SQLite doesn't have a TIME type
departure_time: str
stop_id: int
stop_sequence: int
stop_headsign: str
pickup_type: int
drop_off_type: int
@classmethod
def from_row(cls, row: sqlite3.Row) -> "StopTime":
return cls(
trip_id=row["trip_id"],
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
stop_id=row["stop_id"],
stop_sequence=row["stop_sequence"],
stop_headsign=row["stop_headsign"],
pickup_type=row["pickup_type"],
drop_off_type=row["drop_off_type"],
)
@classmethod
def from_csv_row(cls, row) -> "StopTime":
return cls(
trip_id=row["trip_id"].strip('"'),
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
stop_id=int(row["stop_id"]),
stop_sequence=int(row["stop_sequence"]),
stop_headsign=row["stop_headsign"].strip('"'),
pickup_type=int(row["pickup_type"]),
drop_off_type=int(row["drop_off_type"]),
)
# 2. Create a typed database wrapper
class Database:
conn: sqlite3.Connection
def __init__(self, db_path: str):
self.db_path = db_path
def __enter__(self) -> "Database":
self.conn = sqlite3.connect(self.db_path)
self.conn.row_factory = sqlite3.Row
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.conn.close()
def execute(self, query: str, params={}) -> sqlite3.Cursor:
return self.conn.execute(query, params)
def executemany(self, query: str, params_list=[]) -> sqlite3.Cursor:
return self.conn.executemany(query, params_list)
def commit(self) -> None:
self.conn.commit()
# 3. Create a typed repository for stops
class StopRepository:
db: Database
def __init__(self, db: Database):
self.db = db
def create_table(self) -> None:
self.db.execute("""
CREATE TABLE IF NOT EXISTS stops (
stop_id INTEGER PRIMARY KEY,
stop_code TEXT NOT NULL,
stop_name TEXT NOT NULL,
stop_lat REAL NOT NULL,
stop_lon REAL NOT NULL,
zone_id TEXT NOT NULL
)
""")
self.db.commit()
def insert(self, stop: Stop) -> None:
self.db.execute(
"""
INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
stop.stop_id,
stop.stop_code,
stop.stop_name,
stop.stop_lat,
stop.stop_lon,
stop.zone_id,
),
)
self.db.commit()
def bulk_insert(self, stops: list[Stop]) -> None:
"""Insert multiple stops efficiently."""
self.db.executemany(
"""
INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
VALUES (?, ?, ?, ?, ?, ?)
""",
[
(s.stop_id, s.stop_code, s.stop_name, s.stop_lat, s.stop_lon, s.zone_id)
for s in stops
],
)
self.db.commit()
def load_from_csv(self, csv_path: Path) -> None:
"""Load stops from a CSV file."""
stops: list[Stop] = []
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
try:
stop = Stop.from_csv_row(row)
stops.append(stop)
except (ValueError, KeyError) as e:
print(f"Error processing row: {row}. Error: {e}")
if stops:
self.bulk_insert(stops)
def get_by_id(self, stop_id: int) -> Optional[Stop]:
cursor = self.db.execute("SELECT * FROM stops WHERE stop_id = ?", (stop_id,))
row = cursor.fetchone()
return Stop.from_row(row) if row else None
def get_all(self) -> Iterator[Stop]:
cursor = self.db.execute("SELECT * FROM stops")
for row in cursor:
yield Stop.from_row(row)
def find_by_name(self, name: str) -> List[Stop]:
cursor = self.db.execute(
"SELECT * FROM stops WHERE stop_name LIKE ?", (f"%{name}%",)
)
return [Stop.from_row(row) for row in cursor]
def is_empty(self) -> bool:
"""Check if the stops table is empty."""
cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stops LIMIT 1)")
result = cursor.fetchone()
return not result[0] # SQLite returns 0 if empty, 1 if not empty
def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]:
"""Get upcoming departures from a stop with trip information."""
cursor = self.db.execute("""
SELECT st.*, t.*
FROM stop_times st
JOIN trips t ON st.trip_id = t.trip_id
WHERE st.stop_id = ?
ORDER BY st.departure_time
LIMIT ?
""", (stop_id, limit))
departures = []
for row in cursor:
stop_time = StopTime.from_row(row)
trip = Trip(
route_id=row['route_id'],
service_id=row['service_id'],
trip_id=row['trip_id'],
trip_headsign=row['trip_headsign'],
direction_id=row['direction_id'],
shape_id=row['shape_id'],
wheelchair_accessible=row['wheelchair_accessible'],
brigade=row['brigade']
)
departures.append((stop_time, trip))
return departures
class StopTimeRepository:
db: Database
def __init__(self, db: Database):
self.db = db
def create_table(self) -> None:
self.db.execute("""
CREATE TABLE IF NOT EXISTS stop_times (
trip_id TEXT NOT NULL,
arrival_time TEXT NOT NULL,
departure_time TEXT NOT NULL,
stop_id INTEGER NOT NULL,
stop_sequence INTEGER NOT NULL,
stop_headsign TEXT NOT NULL,
pickup_type INTEGER NOT NULL,
drop_off_type INTEGER NOT NULL,
FOREIGN KEY (stop_id) REFERENCES stops (stop_id)
)
""")
self.db.commit()
def insert(self, stop_time: StopTime) -> None:
self.db.execute(
"""
INSERT INTO stop_times (
trip_id, arrival_time, departure_time, stop_id,
stop_sequence, stop_headsign, pickup_type, drop_off_type
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
stop_time.trip_id,
stop_time.arrival_time,
stop_time.departure_time,
stop_time.stop_id,
stop_time.stop_sequence,
stop_time.stop_headsign,
stop_time.pickup_type,
stop_time.drop_off_type,
),
)
self.db.commit()
def bulk_insert(self, stop_times: list[StopTime]) -> None:
self.db.executemany(
"""
INSERT INTO stop_times (
trip_id, arrival_time, departure_time, stop_id,
stop_sequence, stop_headsign, pickup_type, drop_off_type
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
[
(
st.trip_id,
st.arrival_time,
st.departure_time,
st.stop_id,
st.stop_sequence,
st.stop_headsign,
st.pickup_type,
st.drop_off_type,
)
for st in stop_times
],
)
self.db.commit()
def load_from_csv(self, csv_path: Path) -> None:
stop_times: list[StopTime] = []
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
try:
stop_time = StopTime.from_csv_row(row)
stop_times.append(stop_time)
except (ValueError, KeyError) as e:
print(f"Error processing row: {row}. Error: {e}")
if stop_times:
self.bulk_insert(stop_times)
def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
cursor = self.db.execute(
"SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence",
(trip_id,)
)
return [StopTime.from_row(row) for row in cursor]
def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
cursor = self.db.execute(
"SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time",
(stop_id,)
)
return [StopTime.from_row(row) for row in cursor]
def is_empty(self) -> bool:
cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stop_times LIMIT 1)")
result = cursor.fetchone()
return not result[0]
class TripRepository:
db: Database
def __init__(self, db: Database):
self.db = db
def create_table(self) -> None:
self.db.execute("""
CREATE TABLE IF NOT EXISTS trips (
route_id TEXT NOT NULL,
service_id TEXT NOT NULL,
trip_id TEXT PRIMARY KEY,
trip_headsign TEXT NOT NULL,
direction_id INTEGER NOT NULL,
shape_id INTEGER NOT NULL,
wheelchair_accessible INTEGER NOT NULL,
brigade INTEGER NOT NULL
)
""")
self.db.commit()
def insert(self, trip: Trip) -> None:
self.db.execute(
"""
INSERT INTO trips (
route_id, service_id, trip_id, trip_headsign,
direction_id, shape_id, wheelchair_accessible, brigade
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
trip.route_id,
trip.service_id,
trip.trip_id,
trip.trip_headsign,
trip.direction_id,
trip.shape_id,
trip.wheelchair_accessible,
trip.brigade,
),
)
self.db.commit()
def bulk_insert(self, trips: list[Trip]) -> None:
self.db.executemany(
"""
INSERT INTO trips (
route_id, service_id, trip_id, trip_headsign,
direction_id, shape_id, wheelchair_accessible, brigade
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
[
(
t.route_id,
t.service_id,
t.trip_id,
t.trip_headsign,
t.direction_id,
t.shape_id,
t.wheelchair_accessible,
t.brigade,
)
for t in trips
],
)
self.db.commit()
def load_from_csv(self, csv_path: Path) -> None:
trips: list[Trip] = []
with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
try:
trip = Trip.from_csv_row(row)
trips.append(trip)
except (ValueError, KeyError) as e:
print(f"Error processing row: {row}. Error: {e}")
if trips:
self.bulk_insert(trips)
def get_by_id(self, trip_id: str) -> Optional[Trip]:
cursor = self.db.execute("SELECT * FROM trips WHERE trip_id = ?", (trip_id,))
row = cursor.fetchone()
return Trip.from_row(row) if row else None
def get_by_route_id(self, route_id: str) -> List[Trip]:
cursor = self.db.execute("SELECT * FROM trips WHERE route_id = ?", (route_id,))
return [Trip.from_row(row) for row in cursor]
def get_by_headsign(self, headsign: str) -> List[Trip]:
cursor = self.db.execute(
"SELECT * FROM trips WHERE trip_headsign LIKE ?",
(f"%{headsign}%",)
)
return [Trip.from_row(row) for row in cursor]
def is_empty(self) -> bool:
cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM trips LIMIT 1)")
result = cursor.fetchone()
return not result[0]
def download_data():
response = requests.get("https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile")
open("data.zip", "wb").write(response.content)
with zipfile.ZipFile("data.zip", "r") as zip_ref:
zip_ref.extractall("tmp")
def get_realtime(trip: Trip):
with open("trip_updates.pb", "rb") as f:
response = f.read()
feed = gtfs_realtime_pb2.FeedMessage() # type: ignore
feed.ParseFromString(response)
for entity in feed.entity:
if str(entity.trip_update.trip.trip_id[:-2]) == str(trip.trip_id):
print(entity.id)
print(entity.trip_update)
def main() -> None:
with Database("transit.db") as db:
stop_repo = StopRepository(db)
stop_time_repo = StopTimeRepository(db)
trip_repo = TripRepository(db)
# Create tables
stop_repo.create_table()
stop_time_repo.create_table()
trip_repo.create_table()
# Load data if tables are empty
if stop_repo.is_empty():
stop_repo.load_from_csv(Path("tmp/stops.txt"))
if stop_time_repo.is_empty():
stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
if trip_repo.is_empty():
trip_repo.load_from_csv(Path("tmp/trips.txt"))
stop = stop_repo.find_by_name("Polanka")[0]
for (stop_time, trip) in stop_repo.get_upcoming_departures(stop.stop_id, 1):
print(f'Line: {trip.route_id} Departure: {stop_time.arrival_time}')
get_realtime(trip)
if __name__ == "__main__":
main()