🐮
【EDA Method】How to plot correlation of DataFrame columns
Plot correlation
This time, I'll introduce how to plot the correlation between the columns of a DataFrame and a target column.
It's very useful. Try it for visualization and data analysis.
def plot_target_correlation(
    df: pd.DataFrame,
    target_col: str = 'target',
    n_top_features: int = 30,
    color_sequence: list[str] | None = None,
    template_theme: str = "plotly_white"
) -> None:
    """Plot the features most strongly correlated with the target column.

    Correlations are computed on the numeric (and boolean) columns of ``df``,
    ranked by absolute value, and shown as a horizontal bar chart whose bars
    are colored by the signed correlation coefficient.

    Args:
        df: Input dataframe containing the feature columns and the target
            column. Non-numeric columns are ignored.
        target_col: Name of the target column.
        n_top_features: Maximum number of features to display. Fewer bars are
            drawn when fewer features have a defined correlation.
        color_sequence: Colors used as the bar colorscale. Defaults to None,
            which uses a Plasma-style palette (dark purple to yellow).
        template_theme: Plotly template name used for visual styling.

    Returns:
        None. The figure is displayed with ``fig.show()``.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df``.
        ValueError: If ``target_col`` is not a numeric or boolean column.
    """
    if target_col not in df.columns:
        raise KeyError(f"Target column {target_col!r} not found in dataframe.")

    # Restrict to numeric data up front: DataFrame.corr() raises on
    # string/object columns in recent pandas versions.
    numeric_df = df.select_dtypes(include=["number", "bool"])
    if target_col not in numeric_df.columns:
        raise ValueError(
            f"Target column {target_col!r} must be numeric to compute correlations."
        )

    correlations = numeric_df.corr()[target_col]

    # Drop the target by label rather than by position: after sorting, the
    # target is not guaranteed to be first (another feature may also have
    # |r| == 1, or the target's self-correlation may be NaN if it is constant).
    # Undefined correlations (e.g. constant features) are dropped as well.
    correlations = correlations.drop(labels=target_col).dropna()

    # Rank by absolute strength but keep the signed values for plotting.
    ranked_features = correlations.abs().sort_values(ascending=False).index
    top_correlations = correlations.loc[ranked_features[:max(n_top_features, 0)]]

    feature_names = top_correlations.index
    correlation_values = top_correlations.values

    if color_sequence is None:
        color_sequence = [
            '#0d0887', '#46039f', '#7201a8', '#9c179e', '#bd3786',
            '#d8576b', '#ed7953', '#fb9f3a', '#fdca26', '#f0f921',
        ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feature_names,
        x=correlation_values,
        orientation='h',
        marker=dict(
            color=correlation_values,
            colorscale=color_sequence,
            colorbar=dict(title="Correlation"),
        )
    ))

    # Report the number of features actually plotted, which can be smaller
    # than n_top_features.
    fig.update_layout(
        title=f"<b>Top {len(feature_names)} Features Correlated with {target_col.capitalize()}</b>",
        xaxis_title="<b>Correlation Coefficient</b>",
        yaxis_title="<b>Feature</b>",
        height=800,
        width=1200,
        template=template_theme,
    )

    # Dashed vertical reference line at zero correlation, spanning all bars.
    fig.add_shape(
        type="line",
        x0=0, y0=-0.5,
        x1=0, y1=len(feature_names) - 0.5,
        line=dict(color="black", width=1, dash="dash")
    )

    fig.show()
# Example usage with the default settings; `preprocessed_train_df` is expected
# to be defined earlier and to contain the default 'target' column.
plot_target_correlation(df=preprocessed_train_df)
Discussion