shap.dependence_plot
This notebook demonstrates (and so documents) how to use the shap.dependence_plot function. It uses an XGBoost model trained on the classic UCI Adult income dataset, a classification task that predicts whether a person made over 50k in the 1990s.
import xgboost
import shap
# train XGBoost model
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
# compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
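A dependence plot is a scatter plot that shows the effect a single feature has on the predictions made by the model. Each dot is one row of the data: its x-position is the feature's value and its y-position is the SHAP value for that feature. In this example, the log-odds of making over 50k increase significantly between ages 20 and 40.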
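To make the two axes concrete, here is a small sketch with made-up arrays (`X_toy` and `shap_toy`, not the Adult data) and a hypothetical helper `dependence_points` that returns the two arrays such a plot scatters. Because the SHAP values in this notebook are in log-odds, the sketch also converts log-odds to probabilities with the logistic function, which shows that the same log-odds gap corresponds to a larger probability gap near a 50% baseline than far from it.

```python
import numpy as np

# Made-up toy data (not the Adult dataset): 4 people, 2 features,
# e.g. Age and Education-Num.
X_toy = np.array([[22.0, 9.0],
                  [35.0, 13.0],
                  [48.0, 10.0],
                  [61.0, 16.0]])
# Made-up SHAP values in log-odds units, same shape as X_toy.
shap_toy = np.array([[-1.2, -0.3],
                     [0.4, 0.5],
                     [0.8, -0.1],
                     [0.7, 0.9]])

def dependence_points(feature_idx, shap_values, X):
    """Return the (x, y) arrays a dependence plot scatters for one feature."""
    assert shap_values.shape == X.shape, "SHAP matrix must match the data matrix"
    return X[:, feature_idx], shap_values[:, feature_idx]

def sigmoid(z):
    """Convert log-odds to a probability."""
    return 1.0 / (1.0 + np.exp(-z))

xs, ys = dependence_points(0, shap_toy, X_toy)
print(xs)  # feature values -> x-axis
print(ys)  # SHAP values (log-odds) -> y-axis

# The same log-odds gap maps to different probability gaps depending on the baseline.
for base in (0.0, -2.0):  # hypothetical baseline log-odds
    print(base, sigmoid(base + ys[1]) - sigmoid(base + ys[0]))
```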
# The first argument is the index of the feature we want to plot
# The second argument is the matrix of SHAP values (it is the same shape as the data matrix)
# The third argument is the data matrix (a pandas dataframe or numpy array)
shap.dependence_plot(0, shap_values, X)
# If we pass a numpy array instead of a data frame, then we
# need to pass the feature names in separately
shap.dependence_plot(0, shap_values, X.values, feature_names=X.columns)
# We can pass a feature name instead of an index
shap.dependence_plot("Age", shap_values, X)
# We can also use the special "rank(i)" syntax to pick the feature with
# importance rank i (0-based), where importance is np.abs(shap_values).mean(0).
# In this example, Age is the second most important feature (rank 1).
shap.dependence_plot("rank(1)", shap_values, X)
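The ranking behind "rank(i)" can be reproduced with NumPy. The sketch below uses a made-up 3 x 3 SHAP matrix and a hypothetical helper `rank_to_index`; it follows the importance measure named in the comment above (mean absolute SHAP value per feature) with 0-based ranks, so rank 1 is the second most important feature. It is an illustration, not shap's own code.

```python
import numpy as np

def rank_to_index(rank, shap_values):
    """Column index of the feature with 0-based importance rank `rank`,
    where importance is np.abs(shap_values).mean(0)."""
    importance = np.abs(shap_values).mean(0)
    order = np.argsort(-importance)  # column indices, most important first
    return int(order[rank])

# Made-up SHAP matrix: 3 rows x 3 features.
toy_shap = np.array([[0.5, -2.0, 0.1],
                     [-0.7, 1.5, 0.0],
                     [0.6, -1.0, -0.2]])
print(rank_to_index(0, toy_shap))  # 1: column 1 has the largest mean |SHAP|
print(rank_to_index(1, toy_shap))  # 0: the second most important column
```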
# The interaction_index argument can be used to explicitly
# set which feature gets used for coloring
shap.dependence_plot("rank(1)", shap_values, X, interaction_index="Education-Num")
# we can turn off interaction coloring
shap.dependence_plot("Age", shap_values, X, interaction_index=None)
# we can use shap.approximate_interactions to guess which features
# may interact with age
inds = shap.approximate_interactions("Age", shap_values, X)
# make plots colored by each of the top three possible interacting features
for i in range(3):
    shap.dependence_plot("Age", shap_values, X, interaction_index=inds[i])
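The comment above says only that `shap.approximate_interactions` guesses which features may interact with Age. One plausible heuristic of that kind (an assumption, not necessarily what shap implements) is to sort the rows by the reference feature, split them into chunks of similar values, and score every other feature by how strongly its values correlate with the reference feature's SHAP values inside each chunk. A feature that explains the vertical spread at a fixed Age is a good coloring candidate. `guess_interactions` is a hypothetical name, and the data below is synthetic.

```python
import numpy as np

def guess_interactions(ref_idx, shap_values, X, n_chunks=10):
    """Order candidate features by how well their values explain the vertical
    spread of the reference feature's SHAP values (heuristic sketch)."""
    order = np.argsort(X[:, ref_idx])          # sort rows by the reference feature
    chunks = np.array_split(order, n_chunks)   # groups of rows with similar values
    scores = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        if j == ref_idx:
            continue
        for rows in chunks:
            a = shap_values[rows, ref_idx]
            b = X[rows, j]
            if len(rows) > 2 and a.std() > 0 and b.std() > 0:
                scores[j] += abs(np.corrcoef(a, b)[0, 1])
    scores[ref_idx] = -np.inf                  # never suggest the feature itself
    return np.argsort(-scores)                 # best candidates first

# Synthetic check: feature 0's SHAP value is driven by feature 2, not feature 1.
rng = np.random.default_rng(0)
X_syn = rng.normal(size=(1000, 3))
shap_syn = np.zeros_like(X_syn)
shap_syn[:, 0] = X_syn[:, 0] * X_syn[:, 2]
print(guess_interactions(0, shap_syn, X_syn))  # column 2 should come first
```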
import matplotlib.pyplot as plt
# you can use the cmap parameter to provide your own custom color map
shap.dependence_plot("Age", shap_values, X, cmap=plt.get_cmap("cool"))
# by passing show=False you can prevent shap.dependence_plot from calling
# the matplotlib show() function, and so you can keep customizing the plot
# before eventually calling show yourself
shap.dependence_plot(0, shap_values, X, show=False)
plt.title("Age dependence plot")
plt.ylabel("SHAP value for the 'Age' feature")
# plt.savefig("my_dependence_plot.pdf") # we can save a PDF of the figure if we want
plt.show()
# you can use xmin and xmax with percentile notation to hide outliers
shap.dependence_plot(0, shap_values, X, xmin="percentile(1)", xmax="percentile(99)")
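The "percentile(p)" notation refers to the p-th percentile of the plotted feature's values. Below is a minimal sketch of how such a string could be resolved to a number, using a hypothetical `resolve_limit` helper and `np.nanpercentile` (NumPy's default linear interpolation) on made-up values 1 to 100; points outside [xmin, xmax] then fall outside the visible x-range. It illustrates the notation and is not shap's own parser.

```python
import numpy as np

def resolve_limit(spec, values):
    """Hypothetical helper: turn 'percentile(p)' into a number; pass numbers through."""
    prefix = "percentile("
    if isinstance(spec, str) and spec.startswith(prefix) and spec.endswith(")"):
        p = float(spec[len(prefix):-1])
        return float(np.nanpercentile(values, p))  # ignores NaNs in the feature
    return spec

ages_toy = np.arange(1, 101, dtype=float)  # made-up values 1, 2, ..., 100
xmin = resolve_limit("percentile(1)", ages_toy)
xmax = resolve_limit("percentile(99)", ages_toy)
in_view = (ages_toy >= xmin) & (ages_toy <= xmax)  # points inside the x-limits
print(xmin, xmax, int(in_view.sum()))
```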
# transparency can help reveal dense vs. sparse areas of the scatter plot
shap.dependence_plot(0, shap_values, X, alpha=0.1)
# an alternative to transparency is to reduce the dot size
shap.dependence_plot(0, shap_values, X, dot_size=2)
# for categorical (or binned) data adding a small amount of x-jitter makes
# thin columns of dots more readable
shap.dependence_plot(0, shap_values, X, x_jitter=1, dot_size=1)
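A sketch of what x-jitter does, assuming the noise is uniform, centered on each original value, and scaled to a fraction of the smallest gap between distinct x-values (shap's exact scaling may differ). `jitter` is a hypothetical helper and the values are made up. With an amount of 1, as in the call above, each dot moves by less than half the smallest gap, so dots cannot cross into a neighboring column.

```python
import numpy as np

def jitter(x, amount, rng):
    """Add uniform noise of total width amount * (smallest gap between
    distinct values), centred on each original value."""
    x = np.asarray(x, dtype=float)
    distinct = np.unique(x)  # sorted distinct values
    if len(distinct) < 2 or amount <= 0:
        return x.copy()      # nothing to spread out
    width = amount * np.min(np.diff(distinct))
    return x + rng.uniform(-width / 2, width / 2, size=x.shape)

rng = np.random.default_rng(0)
education_toy = np.array([9, 9, 9, 13, 13, 16])  # made-up binned values
jittered = jitter(education_toy, 1.0, rng)
print(jittered)
```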
X_cat = X.copy()
relationship_decoding = {
0: 'Not-in-family',
1: 'Unmarried',
2: 'Other-relative',
3: 'Own-child',
4: 'Husband',
5: 'Wife'
}
X_cat["Relationship"] = X_cat["Relationship"].map(relationship_decoding)
X_cat.head(3)
# You can use string-valued category features
shap.dependence_plot("Relationship", shap_values, X_cat)
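Plotting string-valued categories requires turning them into positions on the x-axis. One simple encoding (an assumption about the general approach, not a description of shap's internals) is `np.unique(..., return_inverse=True)`: it returns the distinct strings in sorted order plus an integer code per row, and the codes can serve as x positions (or color values, as in the interaction plot below) with the strings as tick labels.

```python
import numpy as np

relationship_toy = np.array(["Husband", "Wife", "Own-child", "Husband", "Unmarried"])
labels, codes = np.unique(relationship_toy, return_inverse=True)
# labels: distinct strings in sorted order; codes: index of each row's label
print(labels)  # ['Husband' 'Own-child' 'Unmarried' 'Wife']
print(codes)   # [0 3 1 0 2]
```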
# It is also possible to use string-valued features to plot interaction effects
shap.dependence_plot(0, shap_values, X_cat, interaction_index="Relationship")