Data Analytics Bootcamp

Session 11: Cohort Analysis

Cohort Analysis
Pandas
Plotly

Cohort Analysis

Cohort analysis is a technique where you group users/customers into cohorts (groups with a shared characteristic) and track how their behavior evolves over time.

A cohort is typically defined by something like:

  • Signup date (most common → acquisition cohorts)
  • First purchase
  • Marketing channel
  • Behavior (e.g., users who used feature X)

Instead of looking at averages across all users, cohort analysis lets you see how different groups behave differently over time.

Important

For groupings that are not time-dependent, the term segment is more appropriate.

Why Cohort Analysis Matters

Aggregate metrics can be misleading. For example, suppose overall retention is 70%, which looks good.

But:

  • Old users → 90% retention
  • New users → 40% retention

Without cohorts, you miss the problem completely.
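To see how a blended figure can hide the gap, here is a small worked example. The user counts (600 old, 400 new) are hypothetical and chosen only so that the blend comes out at 70%; the two retention rates are the ones from the example above.

```python
# Hypothetical user counts; only the retention rates come from the example above
segments = {
    "old": {"users": 600, "retention": 0.90},
    "new": {"users": 400, "retention": 0.40},
}

# Overall retention is a user-weighted average of the segment rates
retained = sum(s["users"] * s["retention"] for s in segments.values())
total = sum(s["users"] for s in segments.values())
overall = retained / total

print(f"Overall retention: {overall:.0%}")
for name, s in segments.items():
    print(f"{name} users: {s['retention']:.0%}")
```

The overall figure depends on the mix of old and new users as much as on how well either group is retained.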

Cohort analysis helps you:

  • Identify when users churn
  • Understand why some cohorts perform better
  • Measure impact of campaigns or product changes
  • Separate growth (new users) from retention (existing users)

Simple Mental Model

Think of cohorts as batches of users starting at the same time:

Cohort (Signup Month)   Month 0   Month 1   Month 2   Month 3
Jan users               100%      70%       55%       40%
Feb users               100%      60%       45%       30%
Mar users               100%      80%       65%       50%

Insights:

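  • The February cohort performs worse than January in every month
  • The March cohort performs better than both
  • Something changed between cohorts (for example the product, pricing, or acquisition mix) and is worth investigating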

Types of Cohorts

Acquisition Cohorts

  • Grouped by start date
  • Used for retention analysis

Behavioral Cohorts

  • Grouped by actions
  • Example: users who completed onboarding vs not

Segment-Based Cohorts

  • Grouped by attributes:
    • Country
    • Gender
    • Channel

Key Metrics

Retention Rate

\[ \text{Retention}(t) = \frac{\text{Active users at time } t}{\text{Users in cohort at start}} \]

Churn Rate

\[ \text{Churn}(t) = 1 - \text{Retention}(t) \]
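As a minimal sketch of the two formulas, the helper below turns active-user counts into retention and churn per period. The function name and the cohort of 200 users are hypothetical; the counts were picked to reproduce the January row of the table in the Simple Mental Model section. Churn here is cumulative, that is, measured against the starting cohort.

```python
def retention_and_churn(active_by_period, cohort_size):
    """Return (retention, churn) lists with one value per period."""
    retention = [active / cohort_size for active in active_by_period]
    churn = [1 - r for r in retention]
    return retention, churn


# Hypothetical cohort of 200 users and their active counts in months 0..3
retention, churn = retention_and_churn([200, 140, 110, 80], cohort_size=200)

print([f"{r:.0%}" for r in retention])
print([f"{c:.0%}" for c in churn])
```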

Lifetime Value (LTV)

Total value generated by a cohort over time.

Cohort Analysis | Steps

  1. Define a cohort: for example, users who signed up in January (monthly cohorts)
  2. Track behavior over time: Activity, purchases, churn
  3. Aggregate by time periods: Month 1, Month 2, Month 3…
  4. Visualize:
    • Heatmap
    • Line plots
  5. Compare cohorts: Identify patterns and anomalies

Typical Output (Heatmap Structure)

  • Rows → Cohorts (by acquisition month)
  • Columns → Time since acquisition
  • Values → Retention %

Example interpretation:

  • Early drop-off → onboarding issue
  • Late drop-off → product value issue
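The steps and the heatmap layout above can be sketched on a toy activity log. Everything in this example is hypothetical (the `events` table, its column names and its values); it only illustrates how rows become cohorts, columns become months since signup, and values become retention.

```python
import pandas as pd

# Hypothetical activity log: one row per user per month in which they were active
events = pd.DataFrame({
    "user_id":      [1, 1, 1, 2, 2, 3, 4, 4],
    "signup_month": ["2024-01"] * 6 + ["2024-02"] * 2,
    "active_month": ["2024-01", "2024-02", "2024-03",
                     "2024-01", "2024-02",
                     "2024-01",
                     "2024-02", "2024-03"],
})

# Steps 1-2: the cohort is the signup month; track months elapsed since signup
signup = pd.to_datetime(events["signup_month"], format="%Y-%m")
active = pd.to_datetime(events["active_month"], format="%Y-%m")
events["period"] = (
    (active.dt.year - signup.dt.year) * 12 + (active.dt.month - signup.dt.month)
)

# Step 3: count distinct active users per cohort and period
active_users = (
    events.groupby(["signup_month", "period"])["user_id"]
          .nunique()
          .unstack("period")
)

# Rows = cohorts, columns = months since signup, values = retention
cohort_size = active_users[0]
retention = active_users.div(cohort_size, axis=0)
print(retention.round(2))
```

Cells left as NaN are cohort and period combinations with no recorded activity in the log, which in real data often means the period has not been observed yet.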

Where It’s Used

  • SaaS (retention, churn)
  • E-commerce (repeat purchases)
  • Telecom (churn analysis, campaigns)
  • Mobile apps (engagement tracking)
Tip

In other words, any service-based business.


Key Insight

Cohort analysis answers:

What happens to users after they join, rather than only how the business performs overall?

This makes it one of the most powerful tools in product and marketing analytics.

Case Study: Cohort Analysis

Business Context

A European subscription-based digital service is analyzing customer behavior to better understand retention dynamics over time.

The company acquires customers through multiple channels such as Organic traffic, Paid Advertising, and Referrals. While overall growth has been steady, management has observed that:

  • Some acquisition periods seem to generate more loyal customers
  • Other periods show faster customer drop-off (churn)

The business wants to understand whether these differences are real and what might be driving them.

Business Objective

Identify which acquisition cohorts perform better or worse over time and understand the possible drivers behind these differences.

Analytical Objective

To achieve this, we will:

  • Explore customer composition at acquisition
  • Compare cohorts across different months
  • Introduce churn behavior
  • Build a cohort-based retention analysis
  • Visualize retention using cohort curves and heatmaps

Loading Packages

import pandas as pd
import numpy as np
from scipy import stats
import plotly.express as px

Reading the Data

df = pd.read_csv('../data/cohort/cohort_analysis.csv',
                 parse_dates = ["acquisition_date"])
df.head()
   user_id  acquisition_date  cancellation_month  gender  marital_status  age  income_segment  country      channel   campaign_id  device_type  plan_type
0  1        2024-07-01        2025-03-01          Male    Married         31   Medium          Germany      Paid Ads  Paid Ads_C   iOS          Standard
1  2        2024-04-01        2024-09-01          Male    Single          54   Premium         Netherlands  Referral  Referral_B   iOS          Standard
2  3        2024-05-01        NaN                 Male    Single          34   Medium          Poland       Paid Ads  Paid Ads_A   Android      Standard
3  4        2024-07-01        NaN                 Male    Married         38   High            Belgium      Organic   Organic_C    Android      Standard
4  5        2024-03-01        2024-04-01          Male    Single          25   Low             Sweden       Paid Ads  Paid Ads_A   Android      Basic
df.shape
(12000, 12)

Variable Description

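The dataset contains customer-level information available at the time of acquisition, along with churn-related variables used later for retention analysis. The churn-related column in the loaded file is cancellation_month, which is missing for customers who have not cancelled. Of the variables listed below, region is not in the file and is derived from country later in this session, and engagement_score does not appear among the 12 columns loaded above.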

Variable          Description
user_id           Unique customer identifier
acquisition_date  Date the customer joined the service
gender            Customer gender
marital_status    Customer marital status
age               Customer age at acquisition
income_segment    Income segment derived from age
country           Customer country
region            Region derived from country (e.g., Western, Eastern Europe)
channel           Acquisition channel (Organic, Paid Ads, Referral)
campaign_id       Campaign identifier linked to acquisition
device_type       Device used at acquisition (Android, iOS, Web)
plan_type         Subscription plan (Basic, Standard, Premium)
engagement_score  Latent customer quality indicator (simulated)

Exploratory Questions | Who are our customers?

Before performing cohort analysis, we need to understand the dataset and the overall profile of the customer base.

Missing Values

df.isna().sum()
user_id                  0
acquisition_date         0
cancellation_month    3581
gender                   0
marital_status           0
age                      0
income_segment         631
country                  0
channel                  0
campaign_id              0
device_type              0
plan_type                0
dtype: int64

Numerical Summary

df.describe()
       user_id      acquisition_date     age
count  12000.00000  12000                12000.000000
mean   6000.50000   2024-04-15 20:46:48  34.630250
min    1.00000      2024-01-01 00:00:00  18.000000
25%    3000.75000   2024-02-01 00:00:00  28.000000
50%    6000.50000   2024-04-01 00:00:00  34.000000
75%    9000.25000   2024-06-01 00:00:00  41.000000
max    12000.00000  2024-08-01 00:00:00  65.000000
std    3464.24595   NaN                  9.514896

Categorical Distributions

categorical_cols = [
    "gender",
    "marital_status",
    "income_segment",
    "country",
    "channel",
    "campaign_id",
    "device_type",
    "plan_type"
]

for col in categorical_cols:
    print(f"\n--- {col} ---")
    print(df[col].value_counts())

--- gender ---
gender
Female    6016
Male      5984
Name: count, dtype: int64

--- marital_status ---
marital_status
Single     7087
Married    4913
Name: count, dtype: int64

--- income_segment ---
income_segment
High       4823
Medium     4267
Low        1627
Premium     652
Name: count, dtype: int64

--- country ---
country
Netherlands       1057
Austria           1031
Poland            1029
Italy             1028
Belgium           1016
Sweden             995
Switzerland        992
Czech Republic     980
Portugal           972
France             969
Spain              968
Germany            963
Name: count, dtype: int64

--- channel ---
channel
Paid Ads    4854
Organic     4690
Referral    2456
Name: count, dtype: int64

--- campaign_id ---
campaign_id
Paid Ads_A    1636
Organic_B     1616
Paid Ads_B    1612
Paid Ads_C    1606
Organic_A     1552
Organic_C     1522
Referral_A     830
Referral_C     816
Referral_B     810
Name: count, dtype: int64

--- device_type ---
device_type
Android    6582
iOS        3664
Web        1754
Name: count, dtype: int64

--- plan_type ---
plan_type
Standard    5979
Basic       3222
Premium     2799
Name: count, dtype: int64

Acquisition Trend

We now complement the descriptive statistics with visual exploration.

acquisition_trend = (
    df.groupby("acquisition_date", as_index=False)
      .agg(customers=("user_id", "count"))
)

fig = px.line(
    acquisition_trend,
    x="acquisition_date",
    y="customers",
    title="Customer Acquisition Trend"
)

fig.show()
Important

As you can see, the plot is not very representative. First we need to convert acquisition_date into acquisition_month, and then rebuild the plot.

df["acquisition_month"] = df["acquisition_date"].dt.to_period("M").astype(str)
df.head()
   user_id  acquisition_date  cancellation_month  gender  marital_status  age  income_segment  country      channel   campaign_id  device_type  plan_type  acquisition_month
0  1        2024-07-01        2025-03-01          Male    Married         31   Medium          Germany      Paid Ads  Paid Ads_C   iOS          Standard   2024-07
1  2        2024-04-01        2024-09-01          Male    Single          54   Premium         Netherlands  Referral  Referral_B   iOS          Standard   2024-04
2  3        2024-05-01        NaN                 Male    Single          34   Medium          Poland       Paid Ads  Paid Ads_A   Android      Standard   2024-05
3  4        2024-07-01        NaN                 Male    Married         38   High            Belgium      Organic   Organic_C    Android      Standard   2024-07
4  5        2024-03-01        2024-04-01          Male    Single          25   Low             Sweden       Paid Ads  Paid Ads_A   Android      Basic      2024-03
acquisition_trend = (
    df.groupby("acquisition_month", as_index=False)
      .agg(customers=("user_id", "count"))
)

fig = px.bar(
    acquisition_trend,
    x="acquisition_month",
    y="customers",
    title="Monthly Acquisition Trend"
)

fig.show()
# Identify peak month
top_month = acquisition_trend.loc[
    acquisition_trend["customers"].idxmax(), "acquisition_month"
]

top_month
'2024-05'

Creating a highlight column to mark the peak month

# Highlight column
acquisition_trend["highlight"] = np.where(
    acquisition_trend["acquisition_month"] == top_month,
    "Peak",
    "Other"
)

Now we can build a more informative plot.

fig = px.bar(
    acquisition_trend,
    x="acquisition_month",
    y="customers",
    color = 'highlight',
    title="Monthly Acquisition Trend"
)

fig.show()
Tip

Try to improve the axis titles for both the x-axis and the y-axis.

Gender Distribution

gender_counts = df["gender"].value_counts().reset_index()
gender_counts.columns = ["gender", "count"]

fig = px.bar(
    gender_counts,
    x="gender",
    y="count",
    title="Gender Distribution"
)

fig.show()

Marital Status Distribution

marital_counts = df["marital_status"].value_counts().reset_index()
marital_counts.columns = ["marital_status", "count"]

fig = px.bar(
    marital_counts,
    x="marital_status",
    y="count",
    title="Marital Status Distribution"
)

fig.show()

Age Distribution

fig = px.histogram(
    df,
    x="age",
    nbins=30,
    title="Age Distribution"
)

fig.show()

Country Distribution

country_counts = df["country"].value_counts().reset_index()
country_counts.columns = ["country", "count"]

fig = px.bar(
    country_counts,
    x="country",
    y="count",
    title="Country Distribution"
)

fig.show()

Channel Distribution

channel_counts = df["channel"].value_counts().reset_index()
channel_counts.columns = ["channel", "count"]

fig = px.bar(
    channel_counts,
    x="channel",
    y="count",
    title="Channel Distribution"
)

fig.show()

Try Yourself | Interpretation Focus

Tip

This step helps us understand:

  • Whether the dataset is balanced or skewed
  • Which customer segments dominate
  • Whether the data looks realistic and usable

Acquisition Month vs Gender

df["acquisition_month"] = df["acquisition_date"].dt.to_period("M").dt.to_timestamp()

gender_trend = (
    df.groupby(["acquisition_month", "gender"], as_index=False)
      .agg(count=("user_id", "count"))
)

fig = px.bar(
    gender_trend,
    x="acquisition_month",
    y="count",
    color="gender",
    barmode="group",
    title="Acquisition by Gender Over Time"
)

fig.show()

Acquisition Month vs Marital Status

marital_trend = (
    df.groupby(["acquisition_month", "marital_status"], as_index=False)
      .agg(count=("user_id", "count"))
)

fig = px.bar(
    marital_trend,
    x="acquisition_month",
    y="count",
    color="marital_status",
    barmode="group",
    title="Acquisition by Marital Status Over Time"
)

fig.show()

Acquisition Month vs Country

country_trend = (
    df.groupby(["acquisition_month", "country"], as_index=False)
      .agg(count=("user_id", "count"))
)

fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    barmode="group",
    title="Acquisition by Country Over Time"
)

fig.show()
Important

The plot above is not informative at all!

How can we fix it?

The country distribution plot shows that the countries are roughly equally represented, so we can create a mapper that groups countries into regions. The resulting groups are much more manageable.

Country-Region Mapper
region_map = {
    "France": "Western Europe",
    "Germany": "Western Europe",
    "Netherlands": "Western Europe",
    "Belgium": "Western Europe",

    "Spain": "Southern Europe",
    "Italy": "Southern Europe",
    "Portugal": "Southern Europe",

    "Poland": "Eastern Europe",
    "Czech Republic": "Eastern Europe",
    "Hungary": "Eastern Europe",
    "Romania": "Eastern Europe",

    "Sweden": "Northern Europe",
    "Norway": "Northern Europe",
    "Denmark": "Northern Europe",
    "Finland": "Northern Europe"
}
Preparing the Data
# Step 1: Map each country to its region
df["region"] = df["country"].map(region_map)

# Step 2: Aggregate
country_trend = (
    df.groupby(["acquisition_month", "region", "country"], as_index=False)
      .agg(count=("user_id", "count"))
)
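One thing to check after a dictionary-based mapping is whether every value was covered. In the country counts above, Austria (1,031 customers) and Switzerland (992) appear in the data but not in region_map, so `.map` returns NaN for them and `groupby` drops those rows by default. This is why the churn-by-region table later in this session covers 9,977 customers rather than 12,000. The sketch below uses a hypothetical mini-sample and a shortened mapper to show the check.

```python
import pandas as pd

# Hypothetical mini-sample; region_map is shortened for illustration
countries = pd.Series(
    ["France", "Austria", "Poland", "Switzerland", "France"], name="country"
)
region_map = {"France": "Western Europe", "Poland": "Eastern Europe"}

region = countries.map(region_map)

# Countries that received no region
unmapped = sorted(countries[region.isna()].unique())
print(unmapped)

# One option: keep unmapped rows visible instead of letting groupby drop them
region_filled = region.fillna("Unmapped")
print(region_filled.value_counts())
```

Whether to extend the mapper or to keep an explicit "Unmapped" group is a judgment call; either way the gap becomes visible.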
Visualizing the results | Subplots
fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    # barmode = "group",
    facet_col="region",
    facet_col_wrap=2,
    title="Acquisition by Country (Split by Region)"
)

fig.show()
Visualizing the results | Better Coloring

Plotly Express has a color_discrete_sequence argument that accepts a list of colors, so we can switch palettes and pick the one we like most.

Popular built-in palettes:

  • px.colors.qualitative.Plotly
  • px.colors.qualitative.Dark24 (better for many categories)
  • px.colors.qualitative.Safe
px.colors.qualitative.Set2
fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    facet_col="region",
    facet_col_wrap=2,
    color_discrete_sequence=px.colors.qualitative.Set2,
    title="Acquisition by Country | 1"
)

fig.show()
px.colors.qualitative.Plotly
fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    facet_col="region",
    facet_col_wrap=2,
    color_discrete_sequence=px.colors.qualitative.Plotly,
    title="Acquisition by Country | 2"
)

fig.show()
px.colors.qualitative.Dark24
fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    facet_col="region",
    facet_col_wrap=2,
    color_discrete_sequence = px.colors.qualitative.Dark24,
    title="Acquisition by Country | 3"
)

fig.show()
px.colors.qualitative.Safe
fig = px.bar(
    country_trend,
    x="acquisition_month",
    y="count",
    color="country",
    facet_col="region",
    facet_col_wrap=2,
    color_discrete_sequence=px.colors.qualitative.Safe,
    title="Acquisition by Country (Styled Colors)"
)

fig.show()
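Custom Color Palette

You can also define your own colors. Because color_map below is a dictionary that maps each country to a color, pass it as color_discrete_map=color_map (color_discrete_sequence expects a list of colors instead).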

color_map = {
    "France": "#1f77b4",
    "Germany": "#ff7f0e",
    "Netherlands": "#2ca02c",
    "Belgium": "#d62728",
    "Spain": "#9467bd",
    "Italy": "#8c564b",
    "Portugal": "#e377c2",
    "Poland": "#7f7f7f",
    "Czech Republic": "#bcbd22",
    "Sweden": "#17becf"
}

Acquisition Month vs Channel

channel_trend = (
    df.groupby(["acquisition_month", "channel"], as_index=False)
      .agg(count=("user_id", "count"))
)

fig = px.bar(
    channel_trend,
    x="acquisition_month",
    y="count",
    color="channel",
    barmode="group",
    title="Acquisition by Channel Over Time"
)

fig.show()

Acquisition Month vs Plan Type

plan_trend = (
    df.groupby(["acquisition_month", "plan_type"], as_index=False)
      .agg(count=("user_id", "count"))
)

fig = px.bar(
    plan_trend,
    x="acquisition_month",
    y="count",
    color="plan_type",
    barmode="group",
    title="Acquisition by Plan Type Over Time"
)

fig.show()

Try yourself

Tip

If composition differs across months, then:

  • Differences in retention may be driven by who was acquired
  • Not necessarily by when they were acquired

This is a critical insight for cohort analysis.
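Grouped bar charts show absolute counts, which makes shifts in composition hard to read when monthly volumes differ. A row-normalised cross-tabulation is one way to compare shares directly. The sketch below uses hypothetical rows with the same column names as the case-study data.

```python
import pandas as pd

# Hypothetical customers; column names follow the case-study data
toy = pd.DataFrame({
    "acquisition_month": ["2024-01"] * 4 + ["2024-02"] * 4,
    "channel": ["Organic", "Organic", "Paid Ads", "Referral",
                "Paid Ads", "Paid Ads", "Paid Ads", "Organic"],
})

# Row-normalised shares: each acquisition month sums to 1
composition = pd.crosstab(
    toy["acquisition_month"], toy["channel"], normalize="index"
)
print(composition.round(2))
```

If the shares move noticeably from one month to the next, cohort differences in retention may partly reflect who was acquired.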

Exploratory Questions | Do some customer groups appear more likely to churn?

Here, our goal is to identify whether certain segments are associated with higher churn. First we create a derived column that flags churned customers, and then we compare churn rates across:

  • Gender
  • Marital status
  • Region
  • Channel
  • Plan type

This will also help us:

  • Form hypotheses about retention drivers
  • Understand whether cohort differences may be explained by customer composition

Defining Churn

# churn = 1 if customer has cancellation date, else 0
df["churn"] = df["cancellation_month"].notna().astype(int)

Compute Churn Rate by Segment

Tip

RECALL: Churn rate is defined as:

\[ \text{Churn Rate} = \frac{\text{Number of Churned Customers}}{\text{Total Customers}} \]

Churn Rate by Gender

gender_churn = (
    df.groupby("gender")
      .agg(
          customers=("user_id", "count"),
          churned=("churn", "sum")
      )
      .assign(churn_rate=lambda x: x["churned"] / x["customers"])
      .reset_index()
)

gender_churn
   gender  customers  churned  churn_rate
0  Female  6016       4289     0.712932
1  Male    5984       4130     0.690174

Churn Rate by Marital Status

marital_churn = (
    df.groupby("marital_status")
      .agg(
          customers=("user_id", "count"),
          churned=("churn", "sum")
      )
      .assign(churn_rate=lambda x: x["churned"] / x["customers"])
      .reset_index()
)

marital_churn
   marital_status  customers  churned  churn_rate
0  Married         4913       3467     0.705679
1  Single          7087       4952     0.698744

Churn Rate by Region

region_churn = (
    df.groupby("region")
      .agg(
          customers=("user_id", "count"),
          churned=("churn", "sum")
      )
      .assign(churn_rate=lambda x: x["churned"] / x["customers"])
      .reset_index()
)

region_churn
   region           customers  churned  churn_rate
0  Eastern Europe   2009       1407     0.700348
1  Northern Europe  995        716      0.719598
2  Southern Europe  2968       2053     0.691712
3  Western Europe   4005       2807     0.700874

Churn Rate by Channel

channel_churn = (
    df.groupby("channel")
      .agg(
          customers=("user_id", "count"),
          churned=("churn", "sum")
      )
      .assign(churn_rate=lambda x: x["churned"] / x["customers"])
      .reset_index()
)

channel_churn
   channel   customers  churned  churn_rate
0  Organic   4690       3286     0.700640
1  Paid Ads  4854       3397     0.699835
2  Referral  2456       1736     0.706840

Churn Rate by Plan Type

plan_churn = (
    df.groupby("plan_type")
      .agg(
          customers=("user_id", "count"),
          churned=("churn", "sum")
      )
      .assign(churn_rate=lambda x: x["churned"] / x["customers"])
      .reset_index()
)

plan_churn
   plan_type  customers  churned  churn_rate
0  Basic      3222       2283     0.708566
1  Premium    2799       1953     0.697749
2  Standard   5979       4183     0.699615
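The five tables above all come from the same aggregation, so it could be wrapped in a small helper. The function name `churn_by` and the toy data are hypothetical; the column names match the ones used above.

```python
import pandas as pd


def churn_by(data, column):
    """Customers, churned customers and churn rate for each value of `column`."""
    return (
        data.groupby(column)
            .agg(customers=("user_id", "count"), churned=("churn", "sum"))
            .assign(churn_rate=lambda x: x["churned"] / x["customers"])
            .reset_index()
    )


# Hypothetical mini-dataset with the same column names as df
toy = pd.DataFrame({
    "user_id": [1, 2, 3, 4, 5, 6],
    "gender": ["Female", "Female", "Female", "Male", "Male", "Male"],
    "plan_type": ["Basic", "Basic", "Premium", "Basic", "Premium", "Premium"],
    "churn": [1, 1, 0, 1, 0, 0],
})

for col in ["gender", "plan_type"]:
    print(churn_by(toy, col), "\n")
```

With the real data, the same loop could run over gender, marital_status, region, channel and plan_type.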

Interpretation

At this stage, we are not yet making causal claims.

Instead, we are:

  • Identifying patterns in churn across segments
  • Detecting potentially risky customer groups
  • Forming hypotheses (e.g., “Do certain plans lead to higher churn?”)

These insights will later connect to cohort analysis, where we evaluate whether differences in retention are driven by:

  • acquisition timing
  • or customer composition
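Before reading much into the gaps between segments, it can help to check whether a gap is larger than sampling noise. The sketch below is a plain two-proportion z-test written with the standard library, applied to the gender counts reported above. It is an illustration, not part of the original analysis; the function name is hypothetical, and `scipy.stats` or `statsmodels` offer equivalent tests.

```python
from math import erf, sqrt


def two_proportion_ztest(x1, n1, x2, n2):
    """Two-sided z-test for the difference between two proportions."""
    p1, p2 = x1 / n1, x2 / n2
    pooled = (x1 + x2) / (n1 + n2)
    se = sqrt(pooled * (1 - pooled) * (1 / n1 + 1 / n2))
    z = (p1 - p2) / se
    # Two-sided p-value from the standard normal CDF
    p_value = 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))
    return z, p_value


# Counts from the gender table: Female 4289 of 6016, Male 4130 of 5984
z, p_value = two_proportion_ztest(4289, 6016, 4130, 5984)
print(f"z = {z:.2f}, p-value = {p_value:.4f}")
```

A difference that is statistically detectable can still be small in practical terms (here 71.3% versus 69.0%), and it says nothing about cause.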

Cohort Analysis

In this final part we will:

  • Define a 12-month observation window
  • Construct cohort indices
  • Build retention tables
  • Visualize cohort performance using heatmaps

This will allow us to answer the key question:

Do customers acquired at different times behave differently over their lifecycle?

The end result will be a cohort retention heatmap, which we build in Step 7.

Step 1: Prepare Cohort Features

# Already created above; repeated here so that this section is self-contained
df["acquisition_month"] = df["acquisition_date"].dt.to_period("M")

# Parse cancellation dates (missing values become NaT), then convert to monthly periods
df["cancellation_month"] = pd.to_datetime(df["cancellation_month"], errors="coerce")
df["cancellation_month"] = df["cancellation_month"].dt.to_period("M")

# Calculate tenure (months until churn)
df["tenure"] = (
    (df["cancellation_month"].dt.year - df["acquisition_month"].dt.year) * 12 +
    (df["cancellation_month"].dt.month - df["acquisition_month"].dt.month)
)
Important

Now we need to make two adjustments:

  1. Replace missing tenure (customers with no cancellation) with the maximum observation window (12 months)
  2. Cap tenure values above 12 at 12
df["tenure"] = df["tenure"].fillna(12)


df["tenure"] = df["tenure"].clip(lower=0, upper=12)
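A sanity check worth running after this step is to confirm that customers without a cancellation end up at tenure 12. In the Step 2 output, the 2024-01 cohort has 461 customers at tenure 0 but none at tenures 1 to 3, and the tenure-0 groups shown in Step 3 (415 to 470 per cohort) are close to what 3,581 missing cancellations spread over eight acquisition months would give. One possible cause, stated here as an assumption to verify in your pandas version, is that `.dt.year` and `.dt.month` on a Period column return -1 rather than NaN for missing values, so `fillna(12)` finds nothing to fill and `clip(lower=0)` moves still-active customers to tenure 0. The sketch below computes tenure from datetime columns instead, where a missing cancellation stays NaN until it is filled. The toy data is hypothetical.

```python
import pandas as pd

WINDOW = 12  # observation window in months, as in the text

# Hypothetical customers; None means the customer has not cancelled
toy = pd.DataFrame({
    "user_id": [1, 2, 3, 4],
    "acquisition_date": ["2024-01-01", "2024-01-01", "2024-02-01", "2024-02-01"],
    "cancellation_month": ["2024-05-01", None, "2025-06-01", None],
})

acq = pd.to_datetime(toy["acquisition_date"])
cancel = pd.to_datetime(toy["cancellation_month"], errors="coerce")

# With datetime columns, a missing cancellation yields NaN in this difference
tenure = (cancel.dt.year - acq.dt.year) * 12 + (cancel.dt.month - acq.dt.month)

# Still-active customers get the full window; longer tenures are capped
toy["tenure"] = tenure.fillna(WINDOW).clip(lower=0, upper=WINDOW).astype(int)

# Sanity check: every customer without a cancellation should sit at WINDOW
assert (toy.loc[cancel.isna(), "tenure"] == WINDOW).all()
print(toy[["user_id", "tenure"]])
```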

Step 2: Build Cohort Table

cohort_data = (
    df.groupby(["acquisition_month", "tenure"])
      .agg(users=("user_id", "count"))
      .reset_index()
)

cohort_data.head()
   acquisition_month  tenure  users
0  2024-01            0       461
1  2024-01            4       4
2  2024-01            5       7
3  2024-01            6       37
4  2024-01            7       70

Step 3: Creating Cohort Size

cohort_size = (
    cohort_data[cohort_data["tenure"] == 0]
    .rename(columns={"users": "cohort_size"})
    [["acquisition_month", "cohort_size"]]
)

cohort_size.head()
    acquisition_month  cohort_size
0   2024-01            461
10  2024-02            470
22  2024-03            415
33  2024-04            439
44  2024-05            456

Step 4: Merge Cohort Size Back

cohort_data = cohort_data.merge(
    cohort_size,
    on="acquisition_month"
)

Step 5: Calculating retention rate

cohort_data["retention_rate"] = (
    cohort_data["users"] / cohort_data["cohort_size"]
)

Step 6: Pivoting

cohort_pivot = cohort_data.pivot(
    index="acquisition_month",
    columns="tenure",
    values="retention_rate"
)

cohort_pivot.head()
tenure             0         1         2         3         4         5         6         7         8         9         10        11        12
acquisition_month
2024-01            1.0       NaN       NaN       NaN       0.008677  0.015184  0.080260  0.151844  0.271150  0.403471  0.433839  0.388286  0.505423
2024-02            1.0       NaN       0.004255  0.014894  0.068085  0.153191  0.270213  0.410638  0.438298  0.417021  0.278723  0.131915  0.063830
2024-03            1.0       0.585542  0.412048  0.448193  0.486747  0.318072  0.175904  0.086747  0.009639  NaN       0.002410  0.004819  NaN
2024-04            1.0       0.280182  0.280182  0.451025  0.464692  0.396355  0.291572  0.179954  0.063781  0.025057  0.004556  NaN       NaN
2024-05            1.0       0.032895  0.074561  0.162281  0.250000  0.425439  0.482456  0.445175  0.278509  0.142544  0.061404  0.017544  0.006579
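Each cell in the table above is the number of customers whose tenure equals that column, divided by the size of the tenure-0 group, so the rows rise and fall instead of declining from 100% the way a retention curve does. A table in the sense of the Retention formula from the Key Metrics section would divide the customers still active at month k by the full cohort size. The sketch below shows one way to compute that, under the assumption that a customer counts as active in month k when tenure >= k. The function name and the toy data are hypothetical.

```python
import pandas as pd

# Hypothetical cohorts; tenure = months until cancellation, capped at the window
toy = pd.DataFrame({
    "acquisition_month": ["2024-01"] * 4 + ["2024-02"] * 4,
    "tenure": [1, 3, 12, 12, 0, 2, 2, 12],
})


def retention_table(data, window=12):
    """Share of each cohort still active at month k (assumed: tenure >= k)."""
    months = list(range(window + 1))
    rows = {
        cohort: [(group["tenure"] >= k).mean() for k in months]
        for cohort, group in data.groupby("acquisition_month")
    }
    table = pd.DataFrame.from_dict(rows, orient="index", columns=months)
    table.index.name = "acquisition_month"
    table.columns.name = "tenure"
    return table


retention = retention_table(toy, window=4)
print(retention)
```

Every row starts at 1.0 and never increases, and the result has the same shape as cohort_pivot, so it could be passed to `px.imshow` in the same way.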

Step 7: Visualizing

# Convert index to string for cleaner display
cohort_pivot_plot = cohort_pivot.copy()
cohort_pivot_plot.index = cohort_pivot_plot.index.astype(str)
cohort_pivot_plot = cohort_pivot_plot.sort_index(ascending=False)


fig = px.imshow(
    cohort_pivot_plot,
    aspect="auto",
    text_auto=".0%",
    color_continuous_scale="Blues"
)

fig.update_layout(
    title="Cohort Retention Heatmap",
    xaxis_title="Months Since Acquisition",
    yaxis_title="Acquisition Month"
)

fig.show()