import geopandas as gpd
import numpy as np
import pandas as pd
import contextily
import palettable.matplotlib as palmpl
import matplotlib.pyplot as plt
import mapclassify
import libpysal
from utils import legendgram
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_predict, GridSearchCV
Appendix M — House price prediction model search
Loop over candidate spatial weights matrices and grid-search hyperparameters to find the optimal house price model.
= "/Users/martin/Library/CloudStorage/OneDrive-SharedLibraries-TheAlanTuringInstitute/Daniel Arribas-Bel - demoland_data" data_folder
Load the data.
data = gpd.read_parquet(f"{data_folder}/processed/interpolated/all_oa.parquet")
Filter only explanatory variables.
exvars = data.drop(
    columns=[
        "geo_code",
        "geometry",
        "air_quality_index",
        "house_price_index",
        "jobs_accessibility_index",
        "greenspace_accessibility_index",
    ]
)
Specify grid search parameters. We can limit the options based on previous exploration.
= {"learning_rate": (0.05, 0.1), "max_iter": [500], "max_bins": (64, 128)} parameters
Define the simple weights matrices.
queen = libpysal.weights.Queen.from_dataframe(data)
weights = {
    "queen": queen,
    "queen2": libpysal.weights.higher_order(queen, k=2, lower_order=True),
    "queen3": libpysal.weights.higher_order(queen, k=3, lower_order=True),
    "queen4": libpysal.weights.higher_order(queen, k=4, lower_order=True),
    "queen5": libpysal.weights.higher_order(queen, k=5, lower_order=True),
    "500m": libpysal.weights.DistanceBand.from_dataframe(data, 500),
    "1000m": libpysal.weights.DistanceBand.from_dataframe(data, 1000),
    "2000m": libpysal.weights.DistanceBand.from_dataframe(data, 2000),
}
/Users/martin/mambaforge/envs/demoland/lib/python3.11/site-packages/libpysal/weights/weights.py:172: UserWarning: The weights matrix is not fully connected:
There are 3 disconnected components.
warnings.warn(message)
/Users/martin/mambaforge/envs/demoland/lib/python3.11/site-packages/libpysal/weights/weights.py:172: UserWarning: The weights matrix is not fully connected:
There are 110 disconnected components.
There are 82 islands with ids: 47, 71, 72, 89, 263, 361, 364, 375, 376, 377, 378, 541, 642, 983, 993, 1092, 1220, 1295, 1339, 1343, 1345, 1383, 1406, 1640, 1756, 1772, 1809, 1851, 1944, 1958, 2124, 2148, 2181, 2182, 2188, 2195, 2214, 2222, 2223, 2237, 2265, 2277, 2281, 2283, 2307, 2361, 2485, 2493, 2594, 2686, 2766, 2809, 2825, 2868, 2940, 2980, 3091, 3094, 3112, 3146, 3191, 3197, 3207, 3223, 3235, 3276, 3397, 3400, 3415, 3419, 3423, 3427, 3451, 3475, 3488, 3528, 3555, 3577, 3707, 3723, 3743, 3778.
warnings.warn(message)
/Users/martin/mambaforge/envs/demoland/lib/python3.11/site-packages/libpysal/weights/weights.py:172: UserWarning: The weights matrix is not fully connected:
There are 16 disconnected components.
There are 12 islands with ids: 89, 377, 378, 1944, 2182, 2277, 2493, 2594, 2868, 3146, 3223, 3528.
warnings.warn(message)
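These warnings flag islands, i.e. observations with no neighbours within the given distance band; after row standardisation their spatial lag is simply zero. If fully connected graphs were preferred, libpysal can attach each island to its nearest neighbour. A minimal sketch, not applied in this notebook (the "500m_connected" key is a hypothetical name):
# Sketch only: link every island to its single nearest neighbour so the
# 500m graph has no empty neighbour sets. Not used in the runs below.
knn1 = libpysal.weights.KNN.from_dataframe(data, k=1)
weights["500m_connected"] = libpysal.weights.util.attach_islands(weights["500m"], knn1)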
Add combined weights on top.
"queen500m"] = libpysal.weights.w_union(weights["queen"], weights["500m"])
weights["queen1000m"] = libpysal.weights.w_union(weights["queen"], weights["1000m"])
weights["queen2000m"] = libpysal.weights.w_union(weights["queen"], weights["2000m"]) weights[
Get a mask to ignore missing values.
mask = data.house_price_index.notna()
Use Grid Search CV to find the best model for each weights option.
meta = {}
for name, W in weights.items():
    # Row-standardise the weights so lags are local means
    W.transform = "r"
    exvars = data.drop(
        columns=[
            "geo_code",
            "geometry",
            "air_quality_index",
            "house_price_index",
            "jobs_accessibility_index",
            "greenspace_accessibility_index",
        ]
    )
    # Augment the explanatory variables with their spatial lags
    for col in exvars.columns.copy():
        exvars[f"{col}_lag"] = libpysal.weights.spatial_lag.lag_spatial(W, exvars[col])
    regressor_lag = HistGradientBoostingRegressor(
        random_state=0,
    )
    est_lag = GridSearchCV(regressor_lag, parameters, verbose=1)
    est_lag.fit(exvars[mask], data.house_price_index[mask])
    meta[name] = {"score": est_lag.best_score_}
    # Out-of-sample predictions from 5-fold cross-validation
    y_pred_lag = cross_val_predict(
        est_lag.best_estimator_, exvars[mask], data.house_price_index[mask], cv=5
    )
    pred_lag = pd.Series(y_pred_lag, index=data.index[mask])
    residuals_lag = data.house_price_index[mask] - pred_lag
    meta[name]["mse"] = mean_squared_error(data.house_price_index[mask], pred_lag)
    meta[name]["me"] = residuals_lag.abs().mean()
    meta[name]["prediction"] = pred_lag
    meta[name]["residuals"] = residuals_lag
    meta[name]["model"] = est_lag.best_estimator_
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
[82 repeated warnings of the form ('WARNING: ', <id>, ' is an island (no neighbors)') for the 500m weights; the ids match the island list reported above.]
Fitting 5 folds for each of 4 candidates, totalling 20 fits
[12 repeated island warnings for the 1000m weights; the ids match the list reported above.]
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Fitting 5 folds for each of 4 candidates, totalling 20 fits
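Before comparing metrics, it may be worth checking which hyperparameters the search settled on for each weights option. A small sketch, relying only on standard scikit-learn estimator attributes:
# Sketch: print the hyperparameters chosen by the grid search per weights option
for name, vals in meta.items():
    model = vals["model"]
    print(name, model.learning_rate, model.max_iter, model.max_bins)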
Save evaluation metrics as series.
= pd.Series([vals["mse"] for vals in meta.values()], index=meta.keys())
mse = pd.Series([vals["me"] for vals in meta.values()], index=meta.keys())
me = pd.Series([vals["score"] for vals in meta.values()], index=meta.keys()) score
Sort according to MSE. Lower is better.
mse.sort_values()
queen5 160164.675606
2000m 160245.606116
queen2000m 160766.019742
1000m 164343.481698
queen1000m 165435.671928
queen4 167802.714113
queen3 170864.677687
queen2 177496.106149
queen500m 185800.082368
500m 186592.783587
queen 201453.657786
dtype: float64
Sort according to ME (the mean absolute residual). Lower is better.
me.sort_values()
2000m 304.115638
queen2000m 305.246902
queen5 306.013012
1000m 310.392479
queen4 311.826728
queen1000m 312.101787
queen3 317.151294
queen2 324.171353
queen500m 330.578553
500m 332.239067
queen 345.019843
dtype: float64
Sort according to R2. Higher is better.
score.sort_values()
queen 0.385277
500m 0.431761
queen500m 0.433833
queen2 0.454928
queen3 0.468809
queen4 0.480325
1000m 0.486258
queen1000m 0.489035
2000m 0.497828
queen2000m 0.498994
queen5 0.504176
dtype: float64
The optimal model seems to use either queen5, 2000m, or a combination of the two. As the original distribution of prices can be a bit bumpy, let's stick with queen5 and explore it.
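As a sanity check, the cross-validated predictions stored in meta can be scored directly; this should land close to the grid-search score reported above. A sketch using scikit-learn's r2_score:
# Sketch: recompute R2 from the stored out-of-sample predictions
from sklearn.metrics import r2_score

r2_score(data.house_price_index[mask], meta["queen5"]["prediction"])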
Plot the actual vs. predicted values.
fig, ax = plt.subplots(figsize=(8, 8))
plt.scatter(data.house_price_index[mask], meta["queen5"]["prediction"], s=0.25)
plt.xlabel("Y test")
plt.ylabel("Y pred")
Text(0, 0.5, 'Y pred')
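A y = x reference line would make over- and under-prediction easier to read off this scatter; an optional sketch on top of the axes above:
# Optional sketch: diagonal reference line for the scatter plot above
lims = (0, data.house_price_index[mask].max())
ax.plot(lims, lims, color="red", linewidth=0.5)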
Plot the values using the original cmap.
from shapely.geometry import box
bds = data.total_bounds
extent = gpd.GeoSeries(
    [box((bds[0] - 7000), bds[1], bds[2] + 7000, bds[3])], crs=data.crs
).to_crs(3857)
f, ax = plt.subplots(figsize=(18, 12))
extent.plot(ax=ax, alpha=0)
bins = mapclassify.NaturalBreaks(data["house_price_index"].dropna().values, k=10).bins

data.assign(pred=meta["queen5"]["prediction"]).to_crs(3857).plot(
    "pred",
    scheme="userdefined",
    classification_kwds={"bins": bins},
    ax=ax,
    alpha=0.9,
    cmap="viridis",
)
legendgram(
    f,
    ax,
    meta["queen5"]["prediction"],
    bins,
    pal=palmpl.Viridis_10,
    legend_size=(0.35, 0.15),  # legend size in fractions of the axis
    loc="lower left",  # matplotlib-style legend locations
    clip=(0, data["house_price_index"].max()),  # clip the displayed range of the histogram
)
ax.set_axis_off()
contextily.add_basemap(
    ax=ax, source=contextily.providers.CartoDB.PositronNoLabels, attribution=""
)
contextily.add_basemap(
    ax=ax,
    source=contextily.providers.Stamen.TonerLines,
    alpha=0.4,
    attribution="(C) CARTO, Map tiles by Stamen Design, CC BY 3.0 -- Map data (C) OpenStreetMap contributors",
)
# plt.savefig(f"{data_folder}/outputs/figures/air_quality_index.png", dpi=150, bbox_inches="tight")
Plot residuals.
f, ax = plt.subplots(figsize=(18, 12))
extent.plot(ax=ax, alpha=0)
data.assign(res=meta["queen5"]["residuals"]).to_crs(3857).plot(
    "res", ax=ax, alpha=0.9, cmap="RdBu", vmin=-2000, vmax=2000, legend=True
)
ax.set_axis_off()
contextily.add_basemap(
    ax=ax, source=contextily.providers.CartoDB.PositronNoLabels, attribution=""
)
contextily.add_basemap(
    ax=ax,
    source=contextily.providers.Stamen.TonerLines,
    alpha=0.4,
    attribution="(C) CARTO, Map tiles by Stamen Design, CC BY 3.0 -- Map data (C) OpenStreetMap contributors",
)
# plt.savefig(f"{data_folder}/outputs/figures/air_quality_index.png", dpi=150, bbox_inches="tight")
"queen5"]["residuals"].plot.hist(bins=25) meta[
There doesn't seem to be any general pattern in where we over- and where we under-predict. The general tendency is captured, so a comparison against this base model should work in the final app and in scenario building.
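To back this visual impression with a statistic, the residuals could be tested for spatial autocorrelation with Moran's I. A sketch assuming the esda package is available (it is not imported above), subsetting the weights to the observations with a known price:
# Sketch: Moran's I of the residuals (assumes esda is installed)
import esda

w_sub = libpysal.weights.w_subset(weights["queen5"], data.index[mask].tolist())
w_sub.transform = "r"
mi = esda.Moran(meta["queen5"]["residuals"].values, w_sub)
print(mi.I, mi.p_sim)  # I near 0 with a high p-value would support the "no pattern" reading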
We can save the meta dict with all the data. The final model is part of that.
import pickle
with open(f"{data_folder}/models/house_price_meta.pickle", "wb") as f:
    pickle.dump(meta, f)
Save just the model for easy inference.
with open(f"{data_folder}/models/house_price_model.pickle", "wb") as f:
    pickle.dump(meta["queen5"]["model"], f)