Appendix J — Air Quality prediction model search

Loop over candidate spatial weights matrices and grid-search hyperparameters to find the optimal air quality model.

import geopandas as gpd
import numpy as np
import pandas as pd

import contextily
import palettable.matplotlib as palmpl
import matplotlib.pyplot as plt
import mapclassify
import libpysal

from utils import legendgram

from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_predict, GridSearchCV

data_folder = "/Users/martin/Library/CloudStorage/OneDrive-SharedLibraries-TheAlanTuringInstitute/Daniel Arribas-Bel - demoland_data"

Load the data.

data = gpd.read_parquet(f"{data_folder}/processed/interpolated/all_oa.parquet")

Keep only the explanatory variables, dropping identifiers, geometry, and the target indices.

exvars = data.drop(
    columns=[
        "geo_code",
        "geometry",
        "air_quality_index",
        "house_price_index",
        "jobs_accessibility_index",
        "greenspace_accessibility_index",
    ]
)
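
Before fitting anything, a quick sanity check of the feature matrix helps catch stray non-numeric columns. A minimal sketch:

# quick sanity check of the explanatory variables (sketch)
print(exvars.shape)
print(exvars.select_dtypes(exclude="number").columns.tolist())  # expected to be empty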

Specify grid search parameters. We can limit the options based on previous exploration.

parameters = {"learning_rate": (0.05, 0.1), "max_iter": [500], "max_bins": (64, 128)}
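
The grid expands to only four candidate combinations (two learning rates × one max_iter × two max_bins). If a full listing is useful, sklearn's ParameterGrid can enumerate them:

from sklearn.model_selection import ParameterGrid

# the grid expands to 2 x 1 x 2 = 4 candidate combinations
list(ParameterGrid(parameters))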

Define the simple weights matrices.

queen = libpysal.weights.Queen.from_dataframe(data)
weights = {
    "queen": queen,
    "queen2": libpysal.weights.higher_order(queen, k=2, lower_order=True),
    "queen3": libpysal.weights.higher_order(queen, k=3, lower_order=True),
    "queen4": libpysal.weights.higher_order(queen, k=4, lower_order=True),
    "queen5": libpysal.weights.higher_order(queen, k=5, lower_order=True),
    "500m": libpysal.weights.DistanceBand.from_dataframe(data, 500),
    "1000m": libpysal.weights.DistanceBand.from_dataframe(data, 1000),
    "2000m": libpysal.weights.DistanceBand.from_dataframe(data, 2000),
}
(libpysal warns that three of the matrices are not fully connected: one has 3 disconnected components, one has 110 disconnected components including 82 islands, and one has 16 disconnected components including 12 islands.)
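
The warnings flag islands, i.e. observations with no neighbours; with row-standardised weights their spatial lag falls back to zero rather than a neighbourhood mean. A short sketch to summarise the connectivity of each matrix:

# summarise connectivity of each weights matrix (sketch)
for name, W in weights.items():
    print(f"{name}: n={W.n}, mean neighbours={W.mean_neighbors:.1f}, islands={len(W.islands)}")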

Use GridSearchCV to find the best model for each weights option.

meta = {}
for name, W in weights.items():
    W.transform = "r"
    exvars = data.drop(
        columns=[
            "geo_code",
            "geometry",
            "air_quality_index",
            "house_price_index",
            "jobs_accessibility_index",
            "greenspace_accessibility_index",
        ]
    )
    for col in exvars.columns.copy():
        exvars[f"{col}_lag"] = libpysal.weights.spatial_lag.lag_spatial(W, exvars[col])
    regressor_lag = HistGradientBoostingRegressor(
        random_state=0,
    )
    est_lag = GridSearchCV(regressor_lag, parameters, verbose=1)
    est_lag.fit(exvars, data.air_quality_index)
    meta[name] = {"score": est_lag.best_score_}
    y_pred_lag = cross_val_predict(
        est_lag.best_estimator_, exvars, data.air_quality_index, cv=5
    )
    pred_lag = pd.Series(y_pred_lag, index=data.index)
    residuals_lag = data.air_quality_index - pred_lag
    meta[name]["mse"] = mean_squared_error(data.air_quality_index, pred_lag)
    meta[name]["me"] = residuals_lag.abs().mean()
    meta[name]["prediction"] = pred_lag
    meta[name]["residuals"] = residuals_lag
    meta[name]["model"] = est_lag.best_estimator_
(GridSearchCV logs "Fitting 5 folds for each of 4 candidates, totalling 20 fits" once per weights option; libpysal additionally prints per-observation island warnings matching the islands reported above.)
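
If the winning hyperparameters for each weights option are also of interest, the loop above could store them as well; a hypothetical one-line addition after fitting:

# hypothetical addition inside the loop, after est_lag.fit(...)
meta[name]["params"] = est_lag.best_params_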

Add combined weights on top.

combined_weights = {
    "queen500m": libpysal.weights.w_union(weights["queen"], weights["500m"]),
    "queen1000m": libpysal.weights.w_union(weights["queen"], weights["1000m"]),
    "queen2000m": libpysal.weights.w_union(weights["queen"], weights["2000m"]),
}
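
One practical benefit of the union is that observations that are islands in a distance band inherit their Queen neighbours. A quick check (sketch) of whether any islands remain:

# verify that the combined matrices have no isolated observations (sketch)
for name, W in combined_weights.items():
    print(name, "islands:", len(W.islands))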

Find models.

for name, W in combined_weights.items():
    W.transform = "r"
    exvars = data.drop(
        columns=[
            "geo_code",
            "geometry",
            "air_quality_index",
            "house_price_index",
            "jobs_accessibility_index",
            "greenspace_accessibility_index",
        ]
    )
    for col in exvars.columns.copy():
        exvars[f"{col}_lag"] = libpysal.weights.spatial_lag.lag_spatial(W, exvars[col])
    regressor_lag = HistGradientBoostingRegressor(
        random_state=0,
    )
    est_lag = GridSearchCV(regressor_lag, parameters, verbose=1)
    est_lag.fit(exvars, data.air_quality_index)
    meta[name] = {"score": est_lag.best_score_}
    y_pred_lag = cross_val_predict(
        est_lag.best_estimator_, exvars, data.air_quality_index, cv=5
    )
    pred_lag = pd.Series(y_pred_lag, index=data.index)
    residuals_lag = data.air_quality_index - pred_lag
    meta[name]["mse"] = mean_squared_error(data.air_quality_index, pred_lag)
    meta[name]["me"] = residuals_lag.abs().mean()
    meta[name]["prediction"] = pred_lag
    meta[name]["residuals"] = residuals_lag
    meta[name]["model"] = est_lag.best_estimator_
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits

Save evaluation metrics as series.

mse = pd.Series([vals["mse"] for vals in meta.values()], index=meta.keys())
me = pd.Series([vals["me"] for vals in meta.values()], index=meta.keys())
score = pd.Series([vals["score"] for vals in meta.values()], index=meta.keys())
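
The three series can also be combined into a single table for a side-by-side view; a minimal sketch:

# combine the evaluation metrics into a single table (sketch)
summary = pd.DataFrame({"mse": mse, "me": me, "r2": score})
summary.sort_values("mse")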

Sort according to MSE. Lower is better.

mse.sort_values()
queen2000m    0.505627
queen4        0.510048
2000m         0.520661
queen5        0.523149
queen3        0.544779
1000m         0.671224
queen1000m    0.677756
queen2        0.714267
queen500m     0.826434
queen         0.919178
500m          0.947078
dtype: float64

Sort according to ME. Lower is better.

me.sort_values()
queen2000m    0.445121
2000m         0.453738
queen4        0.481805
queen5        0.486240
queen3        0.514920
1000m         0.551970
queen1000m    0.558655
queen2        0.601619
queen500m     0.658718
500m          0.708802
queen         0.712700
dtype: float64

Sort according to R² (the best cross-validated score). Higher is better.

score.sort_values()
queen         0.553838
500m          0.563087
queen500m     0.615798
queen2        0.690342
queen1000m    0.721429
1000m         0.724661
queen3        0.762028
queen5        0.780816
queen4        0.782347
2000m         0.803538
queen2000m    0.807554
dtype: float64

The optimal model seems to use a combination of Queen weights and the 2000 m distance band. Let's explore it.

Plot the actual vs. predicted values.

fig, ax = plt.subplots(figsize=(8, 8))
plt.scatter(data.air_quality_index, meta["queen2000m"]["prediction"], s=0.25)
plt.xlabel("Y test")
plt.ylabel("Y pred")
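
A y = x reference line makes over- and under-prediction easier to read off the scatter; a minimal addition (sketch):

# add a y = x reference line to the scatter above (sketch)
lims = (data.air_quality_index.min(), data.air_quality_index.max())
ax.plot(lims, lims, color="red", linewidth=0.5)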

Plot the values using the original cmap.

from shapely.geometry import box

bds = data.total_bounds
extent = gpd.GeoSeries(
    [box((bds[0] - 7000), bds[1], bds[2] + 7000, bds[3])], crs=data.crs
).to_crs(3857)
f, ax = plt.subplots(figsize=(18, 12))
extent.plot(ax=ax, alpha=0)
bins = mapclassify.EqualInterval(data["air_quality_index"].values, k=20).bins

data.assign(pred=meta["queen2000m"]["prediction"]).to_crs(3857).plot(
    "pred",
    scheme="userdefined",
    classification_kwds={"bins": bins},
    ax=ax,
    alpha=0.9,
    cmap="magma_r",
)
legendgram(
    f,
    ax,
    meta["queen2000m"]["prediction"],
    bins,
    pal=palmpl.Magma_20_r,
    legend_size=(0.35, 0.15),  # legend size in fractions of the axis
    loc="lower left",  # matplotlib-style legend locations
    clip=(10, 20),  # clip the displayed range of the histogram
)
ax.set_axis_off()
contextily.add_basemap(
    ax=ax, source=contextily.providers.CartoDB.PositronNoLabels, attribution=""
)
contextily.add_basemap(
    ax=ax,
    source=contextily.providers.Stamen.TonerLines,
    alpha=0.4,
    attribution="(C) CARTO, Map tiles by Stamen Design, CC BY 3.0 -- Map data (C) OpenStreetMap contributors",
)
# plt.savefig(f"{data_folder}/outputs/figures/air_quality_index.png", dpi=150, bbox_inches="tight")

Plot residuals.

f, ax = plt.subplots(figsize=(18, 12))
extent.plot(ax=ax, alpha=0)
data.assign(res=meta["queen2000m"]["residuals"]).to_crs(3857).plot(
    "res", ax=ax, alpha=0.9, cmap="RdBu", vmin=-4, vmax=4, legend=True
)
ax.set_axis_off()
contextily.add_basemap(
    ax=ax, source=contextily.providers.CartoDB.PositronNoLabels, attribution=""
)
contextily.add_basemap(
    ax=ax,
    source=contextily.providers.Stamen.TonerLines,
    alpha=0.4,
    attribution="(C) CARTO, Map tiles by Stamen Design, CC BY 3.0 -- Map data (C) OpenStreetMap contributors",
)
# plt.savefig(f"{data_folder}/outputs/figures/air_quality_index.png", dpi=150, bbox_inches="tight")

It seems that predictions for large OAs are less precise. The green belt tends to be overpredicted, while some more central areas are underpredicted. However, the error across the area is minimal.

We can save the meta dict with all the data; the final model is part of it.

import pickle
with open(f"{data_folder}/models/air_quality_meta.pickle", "wb") as f:
    pickle.dump(meta, f)

Save just the model for easy inference.

with open(f"{data_folder}/models/air_quality_model.pickle", "wb") as f:
    pickle.dump(meta["queen2000m"]["model"], f)
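
For later inference the pickled estimator can be loaded back and applied to the same lagged feature layout. A sketch, assuming exvars still holds the queen2000m lagged features built in the last loop iteration:

# load the saved model and predict on the lagged features (sketch)
with open(f"{data_folder}/models/air_quality_model.pickle", "rb") as model_file:
    model = pickle.load(model_file)
prediction = pd.Series(model.predict(exvars), index=data.index)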