🐮

【EDA Method】How to plot correlation of DataFrame columns

Published

Plot correlation

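This article shows how to plot the correlation between each DataFrame column and a target column. The snippet assumes `import pandas as pd` and `import plotly.graph_objects as go`.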

It's very handy for visualization and data analysis, so give it a try.

def _top_target_correlations(
    df: pd.DataFrame,
    target_col: str,
    n_top_features: int,
) -> pd.Series:
    """Return the features most strongly correlated with the target column.

    Args:
        df (pd.DataFrame):
            The input dataframe containing feature and target columns.
        target_col (str):
            Name of the target column.
        n_top_features (int):
            Maximum number of features to return; values below 1 yield an
            empty result.

    Returns:
        pd.Series: Signed Pearson correlation of each selected feature with
        the target, indexed by feature name and ordered by descending
        absolute value. The target itself is never included, and features
        whose correlation is undefined (NaN, e.g. constant columns) are
        dropped.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df``.
        ValueError: If ``target_col`` is not a numeric or boolean column.
    """
    if target_col not in df.columns:
        raise KeyError(target_col)

    # Non-numeric columns cannot be correlated (DataFrame.corr raises on them
    # in pandas >= 2.0), so restrict the computation to numeric/bool columns.
    numeric = df.select_dtypes(include=["number", "bool"]).astype(float)
    if target_col not in numeric.columns:
        raise ValueError(
            f"Target column {target_col!r} must be numeric or boolean to "
            "compute correlations."
        )

    # Correlate each feature with the target directly instead of building the
    # full feature-by-feature matrix, and exclude the target by label:
    # skipping position 0 of the sorted result is wrong when another feature
    # ties at |r| == 1, or when a constant target's NaN self-correlation
    # sorts last.
    correlations = (
        numeric.drop(columns=target_col)
        .corrwith(numeric[target_col])
        .dropna()  # constant features have an undefined (NaN) correlation
    )

    # Stable sort so features with equal |r| keep a deterministic order.
    ranked = correlations.sort_values(
        key=lambda s: s.abs(), ascending=False, kind="stable"
    )
    return ranked.head(max(n_top_features, 0))


def plot_target_correlation(
    df: pd.DataFrame,
    target_col: str = 'target',
    n_top_features: int = 30,
    color_sequence: list[str] | None = None,
    template_theme: str = "plotly_white"
) -> None:
    """Create a correlation plot showing the top correlated features with the target variable.

    Features are ranked by the absolute value of their Pearson correlation
    with ``target_col``; bars keep the sign of the coefficient and the
    strongest feature is drawn at the top. Only numeric (and boolean)
    columns are considered, and features with an undefined correlation
    (e.g. constant columns) are skipped.

    Args:
        df (pd.DataFrame):
            The input dataframe containing feature and target columns.
        target_col (str, optional):
            Name of the target column.
        n_top_features (int, optional):
            Maximum number of top correlated features to display.
        color_sequence (list[str], optional):
            Custom color sequence for the plot.
            Defaults to None, which uses Plotly's "Plasma" scale
            (dark blue through magenta to yellow).
        template_theme (str, optional):
            Plotly template theme for visual styling.

    Returns:
        None: This function displays a plot and doesn't return anything.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df``.
        ValueError: If ``target_col`` is not a numeric or boolean column.
    """
    top_correlations = _top_target_correlations(df, target_col, n_top_features)

    # Prepare data for plotting
    feature_names = top_correlations.index
    correlation_values = top_correlations.to_numpy()

    # Default to Plotly's "Plasma" sequential scale
    if color_sequence is None:
        color_sequence = ['#0d0887', '#46039f', '#7201a8', '#9c179e', '#bd3786', '#d8576b', '#ed7953', '#fb9f3a', '#fdca26', '#f0f921']

    # Create the bar plot
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feature_names,
        x=correlation_values,
        orientation='h',
        marker=dict(
            color=correlation_values,
            colorscale=color_sequence,
            colorbar=dict(title="Correlation"),
        )
    ))

    # Customize the layout. The title reports how many features are actually
    # shown (there may be fewer than requested) and uses the column name
    # verbatim: str.capitalize() would lowercase e.g. "SalePrice".
    fig.update_layout(
        title=f"<b>Top {len(feature_names)} Features Correlated with {target_col}</b>",
        xaxis_title="<b>Correlation Coefficient</b>",
        yaxis_title="<b>Feature</b>",
        # Horizontal bars are drawn bottom-up; reverse so the strongest
        # feature appears at the top.
        yaxis_autorange="reversed",
        height=800,
        width=1200,
        template=template_theme,
    )

    # Add a vertical reference line at x=0 spanning the full plot height
    # (paper coordinates), independent of how many bars are drawn.
    fig.add_shape(
        type="line",
        xref="x", yref="paper",
        x0=0, y0=0,
        x1=0, y1=1,
        line=dict(color="black", width=1, dash="dash")
    )

    # Display the plot
    fig.show()
    
# Render the chart for the preprocessed training set, using the default
# target column name ('target') and the default number of features.
plot_target_correlation(df=preprocessed_train_df)


Discussion