Scatterplot Matrix in Python
How to make scatterplot matrices or sploms natively in Python with Plotly.
Scatter matrix with Plotly Express
A scatterplot matrix is a matrix of plots associated with n numerical arrays (data variables), $X_1,X_2,…,X_n$, of the same length. The cell (i,j) of such a matrix displays the scatter plot of the variable $X_i$ versus $X_j$.
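To make the cell indexing concrete, the following sketch lists which pair of variables each cell of the matrix displays. It is plain Python with hypothetical variable names and draws nothing; it only illustrates the (i,j) layout described above.

```python
from itertools import product

# Hypothetical variable names, chosen only for illustration.
variables = ["X1", "X2", "X3"]
n = len(variables)

# Cell (i, j) pairs variable i with variable j (1-based, as in the text).
cells = {
    (i, j): (variables[i - 1], variables[j - 1])
    for i, j in product(range(1, n + 1), repeat=2)
}

print(f"{n} variables give {len(cells)} cells")
for i in range(1, n + 1):
    # One printed line per matrix row.
    print("  ".join(f"{cells[i, j][0]} vs {cells[i, j][1]}" for j in range(1, n + 1)))
```

The diagonal cells (i,i) pair each variable with itself, so they say nothing about relationships between different variables.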
Here we use the Plotly Express function px.scatter_matrix to plot the scatter matrix for the columns of a dataframe. By default, all columns are used.
Plotly Express is the easy-to-use, high-level interface to Plotly, which operates on a variety of types of data and produces easy-to-style figures.
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df)
fig.show()
Specify the columns to be represented with the dimensions argument, and set colors using a column of the dataframe:
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df,
    dimensions=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    color="species")
fig.show()
Styled Scatter Matrix with Plotly Express
The scatter matrix plot can be configured through the parameters of px.scatter_matrix, and fine-tuned with fig.update_traces (see the next section to learn more about the options).
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df,
    dimensions=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    color="species", symbol="species",
    title="Scatter matrix of iris data set",
    labels={col: col.replace('_', ' ') for col in df.columns})  # replace underscores with spaces in labels
fig.update_traces(diagonal_visible=False)
fig.show()
Scatter matrix (splom) with go.Splom
If Plotly Express does not provide a good starting point, it is possible to use the more generic go.Splom class from plotly.graph_objects. All its parameters are documented in the reference page https://plotly.com/python/reference/splom/.
The Plotly splom trace does not require setting $x=X_i$ and $y=X_j$ for each scatter plot. All arrays, $X_1,X_2,…,X_n$, are passed once, through a list of dicts called dimensions; that is, each array (variable) represents a dimension.
A trace of type splom is defined as follows:
trace = go.Splom(dimensions=[dict(label='string-1',
                                  values=X1),
                             dict(label='string-2',
                                  values=X2),
                             .
                             .
                             .
                             dict(label='string-n',
                                  values=Xn)],
                 ....
                 )
The label of each dimension is used as the axis title for the corresponding row and column of the matrix.
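When the variables already live in a mapping from names to arrays, the dimensions list can be built with a comprehension instead of writing each dict by hand. The sketch below assumes a small made-up data set; the helper names build_dimensions and make_splom are hypothetical and not part of Plotly.

```python
# Made-up measurements, used only for illustration.
columns = {
    "height": [1.0, 2.0, 3.0, 4.0],
    "width": [2.5, 1.5, 3.5, 0.5],
    "depth": [0.1, 0.4, 0.9, 1.6],
}


def build_dimensions(data):
    """Return one dict per variable, in the label/values form used by go.Splom."""
    return [dict(label=name, values=values) for name, values in data.items()]


def make_splom(data):
    """Wrap the dimensions in a go.Splom trace (hypothetical helper; needs plotly)."""
    # Imported here so that build_dimensions stays usable without plotly installed.
    import plotly.graph_objects as go

    return go.Figure(data=go.Splom(dimensions=build_dimensions(data)))


dimensions = build_dimensions(columns)
print([d["label"] for d in dimensions])
```

With plotly installed, make_splom(columns).show() would display the resulting matrix, with one row and one column per entry of columns.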
Splom of the Iris data set
import plotly.graph_objects as go
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
# The Iris dataset contains four data variables, sepal length, sepal width, petal length,
# petal width, for 150 iris flowers. The flowers are labeled as `Iris-setosa`,
# `Iris-versicolor`, `Iris-virginica`.
# Define indices corresponding to flower categories, using pandas label encoding
index_vals = df['class'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='sepal length',
values=df['sepal length']),
dict(label='sepal width',
values=df['sepal width']),
dict(label='petal length',
values=df['petal length']),
dict(label='petal width',
values=df['petal width'])],
text=df['class'],
marker=dict