wip data model
This commit is contained in:
parent
f1b75e7dae
commit
947fa37aac
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
tmp
|
||||
proto
|
||||
*.txt
|
||||
*.pb
|
||||
*.db
|
||||
**/__pycache__
|
||||
uv.lock
|
@ -3,6 +3,6 @@
|
||||
##### ZTM API docs
|
||||
https://www.ztm.poznan.pl/otwarte-dane/dla-deweloperow/
|
||||
##### ZTM API - static data (routes)
|
||||
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile' -o routes.pb
|
||||
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile' -o routes.zip
|
||||
##### ZTM API - realtime data (delays)
|
||||
curl 'https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb' -o trip_updates.pb
|
||||
|
@ -1,23 +1,510 @@
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
import zipfile
|
||||
import requests
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
|
||||
feed = gtfs_realtime_pb2.FeedMessage()
|
||||
# response = requests.get('https://www.ztm.poznan.pl/pl/dla-deweloperow/getGtfsRtFile?file=trip_updates.pb')
|
||||
# response = ''
|
||||
# with open('trip_updates.pb', 'rb') as f:
|
||||
# response = f.read()
|
||||
# feed.ParseFromString(response)
|
||||
# for entity in feed.entity:
|
||||
# print(entity)
|
||||
# if entity.HasField('trip_update'):
|
||||
# print(entity.trip_update)
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Iterator, Tuple
|
||||
import sqlite3
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Debug scratch: dump every entity of a locally saved GTFS-realtime feed.
with open('feeds.pb', 'rb') as f:
    response = f.read()

# Decode the raw protobuf bytes into a FeedMessage.
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(response)

for entity in feed.entity:
    print(entity)
    # if entity.HasField('trip_update'):
    #     print(entity.trip_update)
|
||||
|
||||
@dataclass
class Stop:
    """One stop record with its identifier, code, name, coordinates and zone."""

    stop_id: int
    stop_code: str
    stop_name: str
    stop_lat: float
    stop_lon: float
    zone_id: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Stop":
        """Build a Stop from a row-like object, copying each column as-is."""
        columns = (
            "stop_id",
            "stop_code",
            "stop_name",
            "stop_lat",
            "stop_lon",
            "zone_id",
        )
        return cls(**{column: row[column] for column in columns})

    @classmethod
    def from_csv_row(cls, row) -> "Stop":
        """Build a Stop from a mapping of CSV strings.

        The id is parsed as ``int`` and the coordinates as ``float``; the
        code and name have surrounding double quotes stripped. ``zone_id``
        is taken verbatim.

        Raises:
            KeyError: if a required column is missing from ``row``.
            ValueError: if a numeric column cannot be parsed.
        """
        # Fields are read in declaration order so the first bad column is
        # the one reported.
        identifier = int(row["stop_id"])
        code = row["stop_code"].strip('"')
        name = row["stop_name"].strip('"')
        latitude = float(row["stop_lat"])
        longitude = float(row["stop_lon"])
        zone = row["zone_id"]
        return cls(
            stop_id=identifier,
            stop_code=code,
            stop_name=name,
            stop_lat=latitude,
            stop_lon=longitude,
            zone_id=zone,
        )
|
||||
|
||||
@dataclass
class Trip:
    """One trip record: route, service, headsign, direction and related ids."""

    route_id: str
    service_id: str
    trip_id: str
    trip_headsign: str
    direction_id: int
    shape_id: int
    wheelchair_accessible: int
    brigade: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Trip":
        """Build a Trip from a row-like object, copying each column as-is."""
        columns = (
            "route_id",
            "service_id",
            "trip_id",
            "trip_headsign",
            "direction_id",
            "shape_id",
            "wheelchair_accessible",
            "brigade",
        )
        return cls(**{column: row[column] for column in columns})

    @classmethod
    def from_csv_row(cls, row) -> "Trip":
        """Build a Trip from a mapping of CSV strings.

        Text columns have surrounding double quotes stripped; the remaining
        columns are parsed as ``int``.

        Raises:
            KeyError: if a required column is missing from ``row``.
            ValueError: if a numeric column cannot be parsed.
        """
        # Fields are read in declaration order so the first bad column is
        # the one reported.
        route = row["route_id"].strip('"')
        service = row["service_id"].strip('"')
        identifier = row["trip_id"].strip('"')
        headsign = row["trip_headsign"].strip('"')
        direction = int(row["direction_id"])
        shape = int(row["shape_id"])
        wheelchair = int(row["wheelchair_accessible"])
        brigade_no = int(row["brigade"])
        return cls(
            route_id=route,
            service_id=service,
            trip_id=identifier,
            trip_headsign=headsign,
            direction_id=direction,
            shape_id=shape,
            wheelchair_accessible=wheelchair,
            brigade=brigade_no,
        )
|
||||
|
||||
@dataclass
class StopTime:
    """One scheduled call of a trip at a stop."""

    trip_id: str
    # Times are kept as text because SQLite has no dedicated TIME type.
    arrival_time: str
    departure_time: str
    stop_id: int
    stop_sequence: int
    stop_headsign: str
    pickup_type: int
    drop_off_type: int

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "StopTime":
        """Build a StopTime from a row-like object, copying each column as-is."""
        columns = (
            "trip_id",
            "arrival_time",
            "departure_time",
            "stop_id",
            "stop_sequence",
            "stop_headsign",
            "pickup_type",
            "drop_off_type",
        )
        return cls(**{column: row[column] for column in columns})

    @classmethod
    def from_csv_row(cls, row) -> "StopTime":
        """Build a StopTime from a mapping of CSV strings.

        ``trip_id`` and ``stop_headsign`` have surrounding double quotes
        stripped, the two time columns are taken verbatim, and the other
        columns are parsed as ``int``.

        Raises:
            KeyError: if a required column is missing from ``row``.
            ValueError: if a numeric column cannot be parsed.
        """
        # Fields are read in declaration order so the first bad column is
        # the one reported.
        trip = row["trip_id"].strip('"')
        arrival = row["arrival_time"]
        departure = row["departure_time"]
        stop = int(row["stop_id"])
        sequence = int(row["stop_sequence"])
        headsign = row["stop_headsign"].strip('"')
        pickup = int(row["pickup_type"])
        drop_off = int(row["drop_off_type"])
        return cls(
            trip_id=trip,
            arrival_time=arrival,
            departure_time=departure,
            stop_id=stop,
            stop_sequence=sequence,
            stop_headsign=headsign,
            pickup_type=pickup,
            drop_off_type=drop_off,
        )
|
||||
|
||||
# 2. Create a typed database wrapper
|
||||
class Database:
    """Context-manager wrapper around a ``sqlite3`` connection.

    The connection is opened on ``__enter__`` with ``sqlite3.Row`` as the
    row factory, so result rows can be indexed by column name, and it is
    closed on ``__exit__``. Nothing is committed automatically: callers
    must call :meth:`commit` themselves.
    """

    conn: sqlite3.Connection

    def __init__(self, db_path: str):
        """Remember the database path; the connection is opened lazily."""
        self.db_path = db_path

    def __enter__(self) -> "Database":
        """Open the connection and return this wrapper."""
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the connection; exceptions from the block propagate."""
        self.conn.close()

    # The defaults below are immutable empty tuples rather than ``{}`` /
    # ``[]`` so that no mutable object is shared between calls.
    def execute(self, query: str, params=()) -> sqlite3.Cursor:
        """Run a single statement.

        Args:
            query: SQL text, optionally with placeholders.
            params: Sequence (for ``?``) or mapping (for named
                placeholders) of values to bind.

        Returns:
            The cursor produced by the connection.
        """
        return self.conn.execute(query, params)

    def executemany(self, query: str, params_list=()) -> sqlite3.Cursor:
        """Run ``query`` once for every parameter set in ``params_list``."""
        return self.conn.executemany(query, params_list)

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()
|
||||
|
||||
|
||||
# 3. Create a typed repository for stops
|
||||
class StopRepository:
    """Persistence and lookup helpers for ``Stop`` records in the ``stops`` table."""

    db: Database

    # Shared by insert() and bulk_insert() so the column list lives in one place.
    _INSERT_SQL = """
            INSERT INTO stops (stop_id, stop_code, stop_name, stop_lat, stop_lon, zone_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _to_params(stop: Stop) -> tuple:
        """Return the stop's fields in the column order of ``_INSERT_SQL``."""
        return (
            stop.stop_id,
            stop.stop_code,
            stop.stop_name,
            stop.stop_lat,
            stop.stop_lon,
            stop.zone_id,
        )

    def create_table(self) -> None:
        """Create the ``stops`` table if it does not exist, then commit."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS stops (
                stop_id INTEGER PRIMARY KEY,
                stop_code TEXT NOT NULL,
                stop_name TEXT NOT NULL,
                stop_lat REAL NOT NULL,
                stop_lon REAL NOT NULL,
                zone_id TEXT NOT NULL
            )
            """)
        self.db.commit()

    def insert(self, stop: Stop) -> None:
        """Insert a single stop and commit."""
        self.db.execute(self._INSERT_SQL, self._to_params(stop))
        self.db.commit()

    def bulk_insert(self, stops: List[Stop]) -> None:
        """Insert multiple stops with one ``executemany`` call and commit."""
        self.db.executemany(self._INSERT_SQL, [self._to_params(s) for s in stops])
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Load stops from a CSV file.

        Rows that raise ``ValueError`` or ``KeyError`` while being parsed are
        reported on stdout and skipped; the remaining rows are inserted in
        one batch.
        """
        stops: List[Stop] = []

        # "utf-8-sig" drops a leading byte-order mark if the file has one.
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                try:
                    stops.append(Stop.from_csv_row(row))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {row}. Error: {e}")

        if stops:
            self.bulk_insert(stops)

    def get_by_id(self, stop_id: int) -> Optional[Stop]:
        """Return the stop with the given id, or ``None`` if there is none."""
        cursor = self.db.execute("SELECT * FROM stops WHERE stop_id = ?", (stop_id,))
        row = cursor.fetchone()
        return Stop.from_row(row) if row else None

    def get_all(self) -> Iterator[Stop]:
        """Lazily yield every stop in the table."""
        cursor = self.db.execute("SELECT * FROM stops")
        for row in cursor:
            yield Stop.from_row(row)

    def find_by_name(self, name: str) -> List[Stop]:
        """Return all stops whose name contains ``name`` (SQL ``LIKE`` match)."""
        cursor = self.db.execute(
            "SELECT * FROM stops WHERE stop_name LIKE ?", (f"%{name}%",)
        )
        return [Stop.from_row(row) for row in cursor]

    def is_empty(self) -> bool:
        """Check if the stops table is empty."""
        cursor = self.db.execute("SELECT EXISTS(SELECT 1 FROM stops LIMIT 1)")
        result = cursor.fetchone()
        return not result[0]  # EXISTS yields 0 when empty, 1 otherwise

    def get_upcoming_departures(self, stop_id: int, limit: int = 10) -> List[Tuple[StopTime, Trip]]:
        """Return up to ``limit`` departures from a stop with their trips.

        Results are ordered by ``departure_time``. The query does not filter
        by the current time or by service day.
        """
        cursor = self.db.execute("""
            SELECT st.*, t.*
            FROM stop_times st
            JOIN trips t ON st.trip_id = t.trip_id
            WHERE st.stop_id = ?
            ORDER BY st.departure_time
            LIMIT ?
            """, (stop_id, limit))

        # The joined row carries the columns of both tables, so each model's
        # own from_row() can pick out its fields; no need to duplicate the
        # Trip field mapping here.
        return [(StopTime.from_row(row), Trip.from_row(row)) for row in cursor]
|
||||
|
||||
|
||||
class StopTimeRepository:
    """Data-access layer for ``StopTime`` records in the ``stop_times`` table."""

    db: Database

    # Used by both insert() and bulk_insert(); column order matches _to_params().
    _INSERT_SQL = """
            INSERT INTO stop_times (
                trip_id, arrival_time, departure_time, stop_id,
                stop_sequence, stop_headsign, pickup_type, drop_off_type
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _to_params(record: StopTime) -> tuple:
        """Return the record's fields in the column order of ``_INSERT_SQL``."""
        return (
            record.trip_id,
            record.arrival_time,
            record.departure_time,
            record.stop_id,
            record.stop_sequence,
            record.stop_headsign,
            record.pickup_type,
            record.drop_off_type,
        )

    def create_table(self) -> None:
        """Create the ``stop_times`` table if it is missing, then commit."""
        ddl = """
            CREATE TABLE IF NOT EXISTS stop_times (
                trip_id TEXT NOT NULL,
                arrival_time TEXT NOT NULL,
                departure_time TEXT NOT NULL,
                stop_id INTEGER NOT NULL,
                stop_sequence INTEGER NOT NULL,
                stop_headsign TEXT NOT NULL,
                pickup_type INTEGER NOT NULL,
                drop_off_type INTEGER NOT NULL,
                FOREIGN KEY (stop_id) REFERENCES stops (stop_id)
            )
            """
        self.db.execute(ddl)
        self.db.commit()

    def insert(self, stop_time: StopTime) -> None:
        """Insert one stop time and commit."""
        self.db.execute(self._INSERT_SQL, self._to_params(stop_time))
        self.db.commit()

    def bulk_insert(self, stop_times: list[StopTime]) -> None:
        """Insert many stop times with one ``executemany`` call and commit."""
        batch = [self._to_params(record) for record in stop_times]
        self.db.executemany(self._INSERT_SQL, batch)
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Parse a CSV file into stop times and insert them in one batch.

        Rows that raise ``ValueError`` or ``KeyError`` are printed and skipped.
        """
        parsed: list[StopTime] = []

        with open(csv_path, "r", newline="", encoding="utf-8-sig") as handle:
            for record in csv.DictReader(handle):
                try:
                    parsed.append(StopTime.from_csv_row(record))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {record}. Error: {e}")

        if parsed:
            self.bulk_insert(parsed)

    def get_by_trip_id(self, trip_id: str) -> List[StopTime]:
        """Return the stop times of a trip ordered by ``stop_sequence``."""
        query = "SELECT * FROM stop_times WHERE trip_id = ? ORDER BY stop_sequence"
        return list(map(StopTime.from_row, self.db.execute(query, (trip_id,))))

    def get_by_stop_id(self, stop_id: int) -> List[StopTime]:
        """Return the stop times at a stop ordered by ``arrival_time``."""
        query = "SELECT * FROM stop_times WHERE stop_id = ? ORDER BY arrival_time"
        return list(map(StopTime.from_row, self.db.execute(query, (stop_id,))))

    def is_empty(self) -> bool:
        """Return ``True`` when the ``stop_times`` table has no rows."""
        has_rows = self.db.execute(
            "SELECT EXISTS(SELECT 1 FROM stop_times LIMIT 1)"
        ).fetchone()[0]
        return not has_rows
|
||||
|
||||
class TripRepository:
    """Data-access layer for ``Trip`` records in the ``trips`` table."""

    db: Database

    # Used by both insert() and bulk_insert(); column order matches _to_params().
    _INSERT_SQL = """
            INSERT INTO trips (
                route_id, service_id, trip_id, trip_headsign,
                direction_id, shape_id, wheelchair_accessible, brigade
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """

    def __init__(self, db: Database):
        self.db = db

    @staticmethod
    def _to_params(record: Trip) -> tuple:
        """Return the trip's fields in the column order of ``_INSERT_SQL``."""
        return (
            record.route_id,
            record.service_id,
            record.trip_id,
            record.trip_headsign,
            record.direction_id,
            record.shape_id,
            record.wheelchair_accessible,
            record.brigade,
        )

    def create_table(self) -> None:
        """Create the ``trips`` table if it is missing, then commit."""
        ddl = """
            CREATE TABLE IF NOT EXISTS trips (
                route_id TEXT NOT NULL,
                service_id TEXT NOT NULL,
                trip_id TEXT PRIMARY KEY,
                trip_headsign TEXT NOT NULL,
                direction_id INTEGER NOT NULL,
                shape_id INTEGER NOT NULL,
                wheelchair_accessible INTEGER NOT NULL,
                brigade INTEGER NOT NULL
            )
            """
        self.db.execute(ddl)
        self.db.commit()

    def insert(self, trip: Trip) -> None:
        """Insert one trip and commit."""
        self.db.execute(self._INSERT_SQL, self._to_params(trip))
        self.db.commit()

    def bulk_insert(self, trips: list[Trip]) -> None:
        """Insert many trips with one ``executemany`` call and commit."""
        batch = [self._to_params(record) for record in trips]
        self.db.executemany(self._INSERT_SQL, batch)
        self.db.commit()

    def load_from_csv(self, csv_path: Path) -> None:
        """Parse a CSV file into trips and insert them in one batch.

        Rows that raise ``ValueError`` or ``KeyError`` are printed and skipped.
        """
        parsed: list[Trip] = []

        with open(csv_path, "r", newline="", encoding="utf-8-sig") as handle:
            for record in csv.DictReader(handle):
                try:
                    parsed.append(Trip.from_csv_row(record))
                except (ValueError, KeyError) as e:
                    print(f"Error processing row: {record}. Error: {e}")

        if parsed:
            self.bulk_insert(parsed)

    def get_by_id(self, trip_id: str) -> Optional[Trip]:
        """Return the trip with the given id, or ``None`` if there is none."""
        match = self.db.execute(
            "SELECT * FROM trips WHERE trip_id = ?", (trip_id,)
        ).fetchone()
        if match:
            return Trip.from_row(match)
        return None

    def get_by_route_id(self, route_id: str) -> List[Trip]:
        """Return every trip that belongs to the given route."""
        query = "SELECT * FROM trips WHERE route_id = ?"
        return list(map(Trip.from_row, self.db.execute(query, (route_id,))))

    def get_by_headsign(self, headsign: str) -> List[Trip]:
        """Return trips whose headsign contains ``headsign`` (SQL ``LIKE`` match)."""
        query = "SELECT * FROM trips WHERE trip_headsign LIKE ?"
        pattern = f"%{headsign}%"
        return list(map(Trip.from_row, self.db.execute(query, (pattern,))))

    def is_empty(self) -> bool:
        """Return ``True`` when the ``trips`` table has no rows."""
        has_rows = self.db.execute(
            "SELECT EXISTS(SELECT 1 FROM trips LIMIT 1)"
        ).fetchone()[0]
        return not has_rows
|
||||
|
||||
def download_data():
    """Download the static GTFS archive and unpack it into ``tmp``.

    The archive is saved as ``data.zip`` in the current working directory
    and then extracted into the ``tmp`` directory.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the server stops responding for 60 seconds.
        zipfile.BadZipFile: if the downloaded payload is not a ZIP archive.
    """
    response = requests.get(
        "https://www.ztm.poznan.pl/pl/dla-deweloperow/getGTFSFile",
        timeout=60,  # without a timeout a stalled server would hang forever
    )
    # Fail fast on HTTP errors instead of saving an error page as data.zip.
    response.raise_for_status()

    # The context manager guarantees the file is flushed and closed before
    # it is reopened for extraction.
    with open("data.zip", "wb") as archive:
        archive.write(response.content)

    with zipfile.ZipFile("data.zip", "r") as zip_ref:
        zip_ref.extractall("tmp")
|
||||
|
||||
def get_realtime(trip: Trip):
    """Print the realtime updates for ``trip`` found in ``trip_updates.pb``.

    The feed is read from a local file, not fetched over the network. For
    every matching entity its id and its trip update are printed.
    """
    with open("trip_updates.pb", "rb") as f:
        payload = f.read()

    feed = gtfs_realtime_pb2.FeedMessage()  # type: ignore
    feed.ParseFromString(payload)

    for entity in feed.entity:
        # The feed's trip id is compared without its last two characters
        # (presumably a suffix the static data lacks -- TODO confirm).
        realtime_id = str(entity.trip_update.trip.trip_id[:-2])
        if realtime_id != str(trip.trip_id):
            continue
        print(entity.id)
        print(entity.trip_update)
|
||||
|
||||
def main() -> None:
    """Set up the local database and print the first departure from "Polanka".

    Creates the three tables in ``transit.db``, loads each one from the
    CSV files under ``tmp`` if it is still empty, looks up the first stop
    whose name contains "Polanka" and prints its earliest listed departure
    together with any matching realtime updates.
    """
    with Database("transit.db") as db:
        stop_repo = StopRepository(db)
        stop_time_repo = StopTimeRepository(db)
        trip_repo = TripRepository(db)

        # Create tables
        stop_repo.create_table()
        stop_time_repo.create_table()
        trip_repo.create_table()

        # Load data if tables are empty
        if stop_repo.is_empty():
            stop_repo.load_from_csv(Path("tmp/stops.txt"))
        if stop_time_repo.is_empty():
            stop_time_repo.load_from_csv(Path("tmp/stop_times.txt"))
        if trip_repo.is_empty():
            trip_repo.load_from_csv(Path("tmp/trips.txt"))

        matches = stop_repo.find_by_name("Polanka")
        if not matches:
            # Indexing an empty result would raise IndexError; report instead.
            print('No stop matching "Polanka" found')
            return
        stop = matches[0]

        for stop_time, trip in stop_repo.get_upcoming_departures(stop.stop_id, 1):
            # The label says "Departure" and the query orders by
            # departure_time, so print that column rather than arrival_time.
            print(f'Line: {trip.route_id} Departure: {stop_time.departure_time}')
            get_realtime(trip)


if __name__ == "__main__":
    main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user