nyc eta engine

#python #pytorch #lightgbm #pandas #mlflow #docker #hugging-face-hub

3-model ensemble (neural net + lightgbm + ft-transformer) for nyc taxi eta prediction. trained on 37m trips, it achieves 252.7s mae — 28% better than the xgboost baseline — with sub-5ms inference.

/ what it does

  • predicts taxi trip duration given pickup zone, dropoff zone, timestamp, and passenger count using 37 million real nyc yellow taxi trips from 2023
  • learns all spatial relationships from trip data via zone embeddings — no external geography, shapefiles, or hardcoded coordinates. if zone ids mapped to a different city, the model would work equally well
  • serves predictions in under 5ms on cpu with a 3-model ensemble, packaged in a ~500mb docker container

/ the ensemble

  • model 1: dual-branch embedding network (560k params) — zone embeddings with hash-based pair embedding, 24 continuous features, residual blocks. best at smooth interpolation for common routes
  • model 2: lightgbm (81 trees) — gradient-boosted trees with zone ids as native categoricals. near-zero bias (-6s) on rare pairs where the neural net struggles (-106s bias)
  • model 3: ft-transformer (406k params) — implemented from scratch, each feature projected into a 128-dim token, 3-layer self-attention with [cls] aggregation. positive bias (+65s) offsets nn's negative bias
  • ensemble weights optimized via grid search on full dev set: 0.6 nn + 0.2 lgbm + 0.2 ft-transformer
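the weight search above can be sketched as a grid over convex combinations, scored by mae on dev predictions. the arrays below are synthetic stand-ins for the three models' dev-set outputs, not the project's data:

```python
# hedged sketch: grid search over convex ensemble weights (step size 0.1).
# synthetic dev-set predictions stand in for the real nn/lgbm/ft outputs.
import itertools
import numpy as np

rng = np.random.default_rng(0)
y_true = rng.uniform(120, 1800, size=1000)   # dev-set trip durations (seconds)
preds = {                                    # per-model dev predictions (illustrative noise)
    "nn": y_true + rng.normal(0, 260, 1000),
    "lgbm": y_true + rng.normal(0, 262, 1000),
    "ft": y_true + rng.normal(65, 285, 1000),
}

def mae(y, p):
    return float(np.mean(np.abs(y - p)))

best = (None, float("inf"))
grid = np.arange(0.0, 1.05, 0.1)
for w_nn, w_lgbm in itertools.product(grid, grid):
    w_ft = 1.0 - w_nn - w_lgbm               # weights constrained to sum to 1
    if w_ft < -1e-9:
        continue
    blended = w_nn * preds["nn"] + w_lgbm * preds["lgbm"] + w_ft * preds["ft"]
    score = mae(y_true, blended)
    if score < best[1]:
        best = ((round(w_nn, 1), round(w_lgbm, 1), round(w_ft, 1)), score)
```

since the grid includes the corner points (1, 0, 0) etc., the blended mae can never be worse than the best single model.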

/ feature engineering

  • 14 zone-pair statistics with bayesian shrinkage — smooths sparse pairs toward pickup-zone mean with a fallback hierarchy: pair → pickup zone → dropoff zone → global mean
  • 6 traffic-regime time buckets (late night, early morning, am rush, midday, pm rush, evening) with per-regime pair statistics
  • 10 temporal features: cyclical hour/dow/month encoding, rush hour flags, night flags, normalized minute-of-day
  • zone-pair median alone (296.7s) beats xgboost (351s) with zero ml — the signal is in the feature engineering
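the shrinkage in the first bullet can be sketched in pandas. column names and the prior strength `k` are assumptions for illustration; only the pair → pickup-zone → global fallback idea comes from the project:

```python
# hedged sketch of bayesian shrinkage for zone-pair statistics:
# sparse pairs are pulled toward the pickup-zone mean (pseudo-count k).
import pandas as pd

trips = pd.DataFrame({
    "pu": [1, 1, 1, 2, 2, 3],
    "do": [5, 5, 6, 5, 5, 7],
    "duration": [300.0, 320.0, 900.0, 400.0, 420.0, 600.0],
})

k = 2.0                                        # prior strength (assumed)
global_mean = trips["duration"].mean()         # last fallback level
pu_mean = trips.groupby("pu")["duration"].mean()

pair = trips.groupby(["pu", "do"])["duration"].agg(["mean", "count"]).reset_index()
prior = pair["pu"].map(pu_mean).fillna(global_mean)   # pickup-zone fallback
pair["shrunk"] = (pair["count"] * pair["mean"] + k * prior) / (pair["count"] + k)
```

the sparse pair (1, 6), observed once at 900s, gets pulled toward pickup zone 1's mean instead of being trusted at face value.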

/ results

  • 252.7s mae — 28% better than xgboost baseline (351s). ensemble reduced mae from 261s (best single model) to 253s
  • nn: 261.2s (precision on common routes) | lgbm: 261.7s (low bias on rare pairs) | ft: 284.7s (different error pattern, bias offset)
  • diagnostic-driven tuning identified rare-pair bias as the true bottleneck — no amount of nn tuning could fix it, lightgbm solved it
  • inference under 5ms per request, total model weights 6.3mb (2.3 + 2.4 + 1.6)

/ how it works

01 download and clean 37m nyc taxi trips (2023), split temporally into train/dev
02 compute zone-pair statistics with bayesian shrinkage across 6 traffic regimes
03 train neural net (37m rows, huber loss), lightgbm (10m rows, mae), ft-transformer (10m rows, l1)
04 optimize ensemble weights via grid search on full 1.23m dev set
05 evaluate on held-out dev set, pick best checkpoint by mae (not training loss)
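the temporal split in step 01 can be sketched as a simple date cutoff — train on earlier months, evaluate on later ones so the dev set is strictly in the future. the cutoff date here is an assumed example, not the project's actual boundary:

```python
# hedged sketch of a temporal train/dev split (assumed cutoff for illustration).
import pandas as pd

trips = pd.DataFrame({
    "pickup_datetime": pd.to_datetime(
        ["2023-01-15", "2023-06-10", "2023-10-02", "2023-11-20", "2023-12-05"]
    ),
    "duration": [350.0, 410.0, 290.0, 500.0, 380.0],
})

cutoff = pd.Timestamp("2023-11-01")            # assumed boundary
train = trips[trips["pickup_datetime"] < cutoff]
dev = trips[trips["pickup_datetime"] >= cutoff]
```

a temporal split (rather than a random one) keeps the eval honest: the model never sees future traffic conditions during training.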

/ features

3-model ensemble
neural net + lightgbm + ft-transformer with complementary strengths. each model has a different inductive bias — embeddings vs tree splits vs self-attention. ensemble reduced mae from 261s to 253s.
ft-transformer from scratch
feature tokenizer transformer (gorishniy et al., neurips 2021). each feature projected into a 128-dim token, [cls] token aggregates via 3-layer self-attention. captures cross-feature interactions the mlp misses.
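the tokenize-then-attend step above can be sketched in pytorch. this minimal version handles only continuous features via a per-feature linear projection (the paper's and project's model also tokenizes categoricals); the 128-dim tokens and 3 layers match the description, the head count and feedforward width are assumptions:

```python
# hedged ft-transformer sketch: one learned token per feature, plus a [cls]
# token that aggregates via self-attention and feeds the regression head.
import torch
import torch.nn as nn

class FTTransformerSketch(nn.Module):
    def __init__(self, n_features: int, d_token: int = 128, n_layers: int = 3):
        super().__init__()
        # per-feature linear projection: x_i -> x_i * w_i + b_i (a d_token vector)
        self.weight = nn.Parameter(torch.randn(n_features, d_token) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_features, d_token))
        self.cls = nn.Parameter(torch.zeros(1, 1, d_token))
        layer = nn.TransformerEncoderLayer(
            d_model=d_token, nhead=8, dim_feedforward=2 * d_token, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_token, 1)

    def forward(self, x):                                   # x: (batch, n_features)
        tokens = x.unsqueeze(-1) * self.weight + self.bias  # (batch, n_features, d_token)
        tokens = torch.cat([self.cls.expand(len(x), -1, -1), tokens], dim=1)
        h = self.encoder(tokens)
        return self.head(h[:, 0]).squeeze(-1)               # predict from [cls]

model = FTTransformerSketch(n_features=24)
out = model(torch.randn(4, 24))
```

self-attention over feature tokens is what lets the model learn cross-feature interactions (e.g. pair statistic × time regime) that a plain mlp has to approximate through depth.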
learned zone embeddings
50-dim embeddings for 266 zones learn spatial relationships purely from trip patterns. no external geography needed — model is transferable to any city with zone ids.
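the embedding lookup can be sketched as below. the 266 zones and 50 dims come from the project; the pair-bucket count, pair-embedding width, and hash function are assumptions for illustration:

```python
# hedged sketch: learned zone embeddings plus a hash-based pair embedding,
# as used by the dual-branch network (model 1). hashing scheme is assumed.
import torch
import torch.nn as nn

N_ZONES, D_ZONE, N_PAIR_BUCKETS = 266, 50, 10_000

zone_emb = nn.Embedding(N_ZONES, D_ZONE)       # one 50-dim vector per zone
pair_emb = nn.Embedding(N_PAIR_BUCKETS, 16)    # hashed (pickup, dropoff) pairs

def pair_bucket(pu: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # fold the ordered pair into a fixed number of buckets
    return (pu * N_ZONES + do) % N_PAIR_BUCKETS

pu = torch.tensor([132, 41])
do = torch.tensor([41, 132])
feats = torch.cat(
    [zone_emb(pu), zone_emb(do), pair_emb(pair_bucket(pu, do))], dim=-1
)
# feats: (2, 116) — pickup + dropoff + pair embedding, fed to the mlp trunk
```

because the ids carry no geography, the embeddings encode only co-occurrence structure in the trips — which is why remapping the ids to another city's zones would leave the architecture unchanged.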
bayesian shrinkage for sparse pairs
handles rare and unseen zone pairs gracefully. shrinkage prior smooths toward pickup-zone mean; fallback hierarchy prevents cold-start failures.
diagnostic-driven tuning
deep diagnostics (parameter health, rare-pair analysis, regularization checks) revealed rare-pair bias as the true bottleneck. prevented wasted experiments on architecture changes.
memory-efficient training
37m rows processed in 2m-row chunks. keeps memory under 6gb, enabling free-tier gpu training on colab/kaggle t4.
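the chunking arithmetic can be sketched as a slicing helper — the 2m-row chunk size comes from the description, the helper itself is illustrative:

```python
# hedged sketch: iterate a 37m-row dataset in fixed 2m-row slices so only
# one chunk is resident in memory at a time.
def iter_chunks(n_rows: int, chunk_size: int = 2_000_000):
    """yield (start, stop) half-open ranges covering n_rows."""
    for start in range(0, n_rows, chunk_size):
        yield start, min(start + chunk_size, n_rows)

# 37m rows -> 19 chunks, the last one partial (1m rows)
chunks = list(iter_chunks(37_000_000))
```

each (start, stop) range would drive a partial read (e.g. parquet row groups or `pd.read_parquet` per file) followed by feature computation and a training pass, keeping peak memory bounded regardless of total dataset size.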