import pandas as pd
import numpy as np
from scipy import stats
import plotly.express as px
Session 11: Cohort Analysis
Cohort Analysis
Cohort analysis is a technique where you group users/customers into cohorts (groups with a shared characteristic) and track how their behavior evolves over time.
A cohort is typically defined by something like:
- Signup date (most common → acquisition cohorts)
- First purchase
- Marketing channel
- Behavior (e.g., users who used feature X)
Instead of looking at averages across all users, cohort analysis shows how each group's behavior evolves over time.
For groupings that are not time-based (e.g., country or gender), the term segment is more appropriate.
Why Cohort Analysis Matters
Aggregate metrics can be misleading. For example, suppose overall retention is 70%, which looks good. But:
- Old users → 90% retention
- New users → 40% retention
Without cohorts, you miss the problem completely.
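To see how a blended 70% can hide a weak group, here is the arithmetic with made-up counts, assuming 600 old and 400 new users (the split is hypothetical; any mix with 60% old users gives the same result):

```python
import pandas as pd

# Hypothetical split: 600 old users and 400 new users (illustrative counts)
groups = pd.DataFrame(
    {"users": [600, 400], "retained": [540, 160]},
    index=["Old users", "New users"],
)

# Retention per group: 540/600 = 90%, 160/400 = 40%
groups["retention"] = groups["retained"] / groups["users"]

# Overall retention is a user-weighted average of the group rates
overall_retention = groups["retained"].sum() / groups["users"].sum()

print(groups)
print(f"Overall retention: {overall_retention:.0%}")  # 700/1000 -> 70%
```

The overall figure is a user-weighted average of the group rates, so it also moves when the mix of old and new users changes, even if neither group's retention does.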
Cohort analysis helps you:
- Identify when users churn
- Understand why some cohorts perform better
- Measure impact of campaigns or product changes
- Separate growth (new users) from retention (existing users)
Simple Mental Model
Think of cohorts as batches of users starting at the same time:
| Cohort (Signup Month) | Month 0 | Month 1 | Month 2 | Month 3 |
|---|---|---|---|---|
| Jan users | 100% | 70% | 55% | 40% |
| Feb users | 100% | 60% | 45% | 30% |
| Mar users | 100% | 80% | 65% | 50% |
Insights:
- February cohort performs worse
- March cohort performs better
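- Something changed between these acquisition periods (e.g., product, onboarding, or acquisition channels) that is worth investigating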
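The same comparison can be made explicit in pandas. This sketch rebuilds the illustrative table above (values in %) and measures each cohort against the average of all cohorts at each month:

```python
import pandas as pd

# Illustrative retention table from the mental model above (values in %)
retention = pd.DataFrame(
    {
        "Month 0": [100, 100, 100],
        "Month 1": [70, 60, 80],
        "Month 2": [55, 45, 65],
        "Month 3": [40, 30, 50],
    },
    index=["Jan users", "Feb users", "Mar users"],
)

# Gap between each cohort and the cross-cohort average at every month
gap_vs_average = retention - retention.mean()

best_cohort = retention["Month 3"].idxmax()   # highest retention after 3 months
worst_cohort = retention["Month 3"].idxmin()  # lowest retention after 3 months

print(gap_vs_average)
print(f"Best: {best_cohort}, worst: {worst_cohort}")
```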
Types of Cohorts
Acquisition Cohorts
- Grouped by start date
- Used for retention analysis
Behavioral Cohorts
- Grouped by actions
- Example: users who completed onboarding vs not
Segment-Based Cohorts
- Grouped by attributes:
  - Country
  - Gender
  - Channel
Key Metrics
Retention Rate
\[ \text{Retention}(t) = \frac{\text{Active users at time } t}{\text{Users in cohort at start}} \]
Churn Rate
\[ \text{Churn}(t) = 1 - \text{Retention}(t) \]
Lifetime Value (LTV)
Total value generated by a cohort over time.
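A minimal sketch of the three metrics for a single cohort. The cohort size, the monthly active counts and the average revenue per active user (`arpu`) are hypothetical, and LTV is approximated here as the revenue accumulated over the observed months only:

```python
import numpy as np

# Hypothetical cohort: 100 users at the start, active users in months 0-3
cohort_size = 100
active_users = np.array([100, 70, 55, 40])

# Assumed average revenue per active user per month (illustrative only)
arpu = 10.0

retention = active_users / cohort_size  # Retention(t)
churn = 1 - retention                   # Churn(t) = 1 - Retention(t)

# Value over the observed window: revenue from every active user-month
cohort_value = (active_users * arpu).sum()   # 2650.0 for the whole cohort
value_per_user = cohort_value / cohort_size  # 26.5 per acquired user

for month, (r, c) in enumerate(zip(retention, churn)):
    print(f"Month {month}: retention {r:.0%}, churn {c:.0%}")
print(f"Cohort value: {cohort_value:.0f}, per user: {value_per_user:.2f}")
```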
Cohort Analysis | Steps
- Define a cohort: e.g., users who signed up in January (monthly cohorts)
- Track behavior over time: activity, purchases, churn
- Aggregate by time periods: Month 1, Month 2, Month 3…
- Visualize:
  - Heatmap
  - Line plots
- Compare cohorts: identify patterns and anomalies
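The steps above can be sketched on a tiny, made-up activity log. In this variant a user's cohort is their first active month, and retention counts distinct active users per month since that cohort month; the column names `user_id` and `active_month` are assumptions for illustration, not part of the case-study data:

```python
import pandas as pd

# Hypothetical activity log: one row per user per month in which the user was active
events = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2, 3, 4, 4, 5],
    "active_month": ["2024-01", "2024-02", "2024-03", "2024-01", "2024-02",
                     "2024-01", "2024-02", "2024-03", "2024-02"],
})
events["active_month"] = pd.to_datetime(events["active_month"])  # first day of each month

# 1. Define the cohort: each user's first active month
first_month = events.groupby("user_id")["active_month"].transform("min")
events["cohort"] = first_month.dt.strftime("%Y-%m")

# 2. Track behavior over time: whole months elapsed since the cohort month
events["period"] = (
    (events["active_month"].dt.year - first_month.dt.year) * 12
    + (events["active_month"].dt.month - first_month.dt.month)
)

# 3. Aggregate: distinct active users per cohort and period (0 where nobody was active)
active = (
    events.groupby(["cohort", "period"])["user_id"]
    .nunique()
    .unstack(fill_value=0)
)

# 4. Compare: divide each row by its period-0 size to get retention
retention = active.div(active[0], axis=0)
print(retention.round(2))
```

Plotting `retention` as a heatmap (for example with `px.imshow`, as in the case study below) covers the visualization step.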
Typical Output (Heatmap Structure)
- Rows → Cohorts (by acquisition month)
- Columns → Time since acquisition
- Values → Retention %
Example interpretation:
- Early drop-off → onboarding issue
- Late drop-off → product value issue
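One simple way to tell early from late drop-off is to look at month-over-month changes along each row and find where the biggest loss happens. A sketch using the illustrative numbers from the mental-model table:

```python
import pandas as pd

# Illustrative retention table (% of cohort still active), columns = months since acquisition
retention = pd.DataFrame(
    [[100, 70, 55, 40], [100, 60, 45, 30], [100, 80, 65, 50]],
    index=["Jan users", "Feb users", "Mar users"],
    columns=[0, 1, 2, 3],
)

# Month-over-month change (negative = share of the cohort lost in that month)
drops = retention.diff(axis=1).drop(columns=0)

# Month with the steepest loss for each cohort
steepest_month = drops.idxmin(axis=1)

print(drops)
print(steepest_month)
```

In those numbers every cohort loses the most users between month 0 and month 1, which, under the reading above, would point towards onboarding rather than long-term product value.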

Where It’s Used
- SaaS (retention, churn)
- E-commerce (repeat purchases)
- Telecom (churn analysis, campaigns)
- Mobile apps (engagement tracking)
In other words, this applies to almost any business that provides an ongoing service.
Key Insight
Cohort analysis answers the question: what happens to users after they join? It looks beyond overall performance.
This makes it one of the most powerful tools in product and marketing analytics.
Case Study: Cohort Analysis
Business Context
A European subscription-based digital service is analyzing customer behavior to better understand retention dynamics over time.
The company acquires customers through multiple channels such as Organic traffic, Paid Advertising, and Referrals. While overall growth has been steady, management has observed that:
- Some acquisition periods seem to generate more loyal customers
- Other periods show faster customer drop-off (churn)
The business wants to understand whether these differences are real and what might be driving them.
Business Objective
Identify which acquisition cohorts perform better or worse over time and understand the possible drivers behind these differences.
Analytical Objective
To achieve this, we will:
- Explore customer composition at acquisition
- Compare cohorts across different months
- Introduce churn behavior
- Build a cohort-based retention analysis
- Visualize retention using cohort curves and heatmaps
Loading Packages
Reading the Data
df = pd.read_csv('../data/cohort/cohort_analysis.csv',
parse_dates=["acquisition_date"])
df.head()
| | user_id | acquisition_date | cancellation_month | gender | marital_status | age | income_segment | country | channel | campaign_id | device_type | plan_type |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 2024-07-01 | 2025-03-01 | Male | Married | 31 | Medium | Germany | Paid Ads | Paid Ads_C | iOS | Standard |
| 1 | 2 | 2024-04-01 | 2024-09-01 | Male | Single | 54 | Premium | Netherlands | Referral | Referral_B | iOS | Standard |
| 2 | 3 | 2024-05-01 | NaN | Male | Single | 34 | Medium | Poland | Paid Ads | Paid Ads_A | Android | Standard |
| 3 | 4 | 2024-07-01 | NaN | Male | Married | 38 | High | Belgium | Organic | Organic_C | Android | Standard |
| 4 | 5 | 2024-03-01 | 2024-04-01 | Male | Single | 25 | Low | Sweden | Paid Ads | Paid Ads_A | Android | Basic |
df.shape
(12000, 12)
Variable Description
The dataset contains customer-level information available at the time of acquisition, along with churn-related variables used later for retention analysis.
| Variable | Description |
|---|---|
| user_id | Unique customer identifier |
| acquisition_date | Date the customer joined the service |
| cancellation_month | Month in which the customer cancelled (missing if the customer has not cancelled) |
| gender | Customer gender |
| marital_status | Customer marital status |
| age | Customer age at acquisition |
| income_segment | Income segment derived from age |
| country | Customer country |
| region | Region derived from country (e.g., Western, Eastern Europe); created later in this notebook |
| channel | Acquisition channel (Organic, Paid Ads, Referral) |
| campaign_id | Campaign identifier linked to acquisition |
| device_type | Device used at acquisition (Android, iOS, Web) |
| plan_type | Subscription plan (Basic, Standard, Premium) |
| engagement_score | Latent customer quality indicator (simulated); not included in the loaded file |
Exploratory Questions | Who are our customers?
Before performing cohort analysis, we need to understand the dataset and customer composition.
Our goal here is to understand the overall profile of the customer base.
Missing Values
df.isna().sum()
user_id 0
acquisition_date 0
cancellation_month 3581
gender 0
marital_status 0
age 0
income_segment 631
country 0
channel 0
campaign_id 0
device_type 0
plan_type 0
dtype: int64
Numerical Summary
df.describe()
| | user_id | acquisition_date | age |
|---|---|---|---|
| count | 12000.00000 | 12000 | 12000.000000 |
| mean | 6000.50000 | 2024-04-15 20:46:48 | 34.630250 |
| min | 1.00000 | 2024-01-01 00:00:00 | 18.000000 |
| 25% | 3000.75000 | 2024-02-01 00:00:00 | 28.000000 |
| 50% | 6000.50000 | 2024-04-01 00:00:00 | 34.000000 |
| 75% | 9000.25000 | 2024-06-01 00:00:00 | 41.000000 |
| max | 12000.00000 | 2024-08-01 00:00:00 | 65.000000 |
| std | 3464.24595 | NaN | 9.514896 |
Categorical Distributions
categorical_cols = [
"gender",
"marital_status",
"income_segment",
"country",
"channel",
"campaign_id",
"device_type",
"plan_type"
]
for col in categorical_cols:
    print(f"\n--- {col} ---")
    print(df[col].value_counts())
--- gender ---
gender
Female 6016
Male 5984
Name: count, dtype: int64
--- marital_status ---
marital_status
Single 7087
Married 4913
Name: count, dtype: int64
--- income_segment ---
income_segment
High 4823
Medium 4267
Low 1627
Premium 652
Name: count, dtype: int64
--- country ---
country
Netherlands 1057
Austria 1031
Poland 1029
Italy 1028
Belgium 1016
Sweden 995
Switzerland 992
Czech Republic 980
Portugal 972
France 969
Spain 968
Germany 963
Name: count, dtype: int64
--- channel ---
channel
Paid Ads 4854
Organic 4690
Referral 2456
Name: count, dtype: int64
--- campaign_id ---
campaign_id
Paid Ads_A 1636
Organic_B 1616
Paid Ads_B 1612
Paid Ads_C 1606
Organic_A 1552
Organic_C 1522
Referral_A 830
Referral_C 816
Referral_B 810
Name: count, dtype: int64
--- device_type ---
device_type
Android 6582
iOS 3664
Web 1754
Name: count, dtype: int64
--- plan_type ---
plan_type
Standard 5979
Basic 3222
Premium 2799
Name: count, dtype: int64
Acquisition Trend
We now complement the descriptive statistics with visual exploration.
acquisition_trend = (
df.groupby("acquisition_date", as_index=False)
.agg(customers=("user_id", "count"))
)
fig = px.line(
acquisition_trend,
x="acquisition_date",
y="customers",
title="Customer Acquisition Trend"
)
fig.show()
As you can see, this plot is not very informative. First, we convert acquisition_date into an acquisition_month column and then build a clearer chart.
df["acquisition_month"] = df["acquisition_date"].dt.to_period("M").astype(str)
df.head()
| | user_id | acquisition_date | cancellation_month | gender | marital_status | age | income_segment | country | channel | campaign_id | device_type | plan_type | acquisition_month |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 2024-07-01 | 2025-03-01 | Male | Married | 31 | Medium | Germany | Paid Ads | Paid Ads_C | iOS | Standard | 2024-07 |
| 1 | 2 | 2024-04-01 | 2024-09-01 | Male | Single | 54 | Premium | Netherlands | Referral | Referral_B | iOS | Standard | 2024-04 |
| 2 | 3 | 2024-05-01 | NaN | Male | Single | 34 | Medium | Poland | Paid Ads | Paid Ads_A | Android | Standard | 2024-05 |
| 3 | 4 | 2024-07-01 | NaN | Male | Married | 38 | High | Belgium | Organic | Organic_C | Android | Standard | 2024-07 |
| 4 | 5 | 2024-03-01 | 2024-04-01 | Male | Single | 25 | Low | Sweden | Paid Ads | Paid Ads_A | Android | Basic | 2024-03 |
acquisition_trend = (
df.groupby("acquisition_month", as_index=False)
.agg(customers=("user_id", "count"))
)
fig = px.bar(
acquisition_trend,
x="acquisition_month",
y="customers",
title="Monthly Acquisition Trend"
)
fig.show()
# Identify peak month
top_month = acquisition_trend.loc[
acquisition_trend["customers"].idxmax(), "acquisition_month"
]
top_month
'2024-05'
Create a highlight column to mark the peak month
# Highlight column
acquisition_trend["highlight"] = np.where(
acquisition_trend["acquisition_month"] == top_month,
"Peak",
"Other"
)
Now we can build a more informative plot.
fig = px.bar(
acquisition_trend,
x="acquisition_month",
y="customers",
color = 'highlight',
title="Monthly Acquisition Trend"
)
fig.show()
Try yourself: give the x-axis and the y-axis more descriptive titles.
Gender Distribution
gender_counts = df["gender"].value_counts().reset_index()
gender_counts.columns = ["gender", "count"]
fig = px.bar(
gender_counts,
x="gender",
y="count",
title="Gender Distribution"
)
fig.show()
Marital Status Distribution
marital_counts = df["marital_status"].value_counts().reset_index()
marital_counts.columns = ["marital_status", "count"]
fig = px.bar(
marital_counts,
x="marital_status",
y="count",
title="Marital Status Distribution"
)
fig.show()
Age Distribution
fig = px.histogram(
df,
x="age",
nbins=30,
title="Age Distribution"
)
fig.show()
Country Distribution
country_counts = df["country"].value_counts().reset_index()
country_counts.columns = ["country", "count"]
fig = px.bar(
country_counts,
x="country",
y="count",
title="Country Distribution"
)
fig.show()
Channel Distribution
channel_counts = df["channel"].value_counts().reset_index()
channel_counts.columns = ["channel", "count"]
fig = px.bar(
channel_counts,
x="channel",
y="count",
title="Channel Distribution"
)
fig.show()
Try Yourself | Interpretation Focus
This step helps us understand:
- Whether the dataset is balanced or skewed
- Which customer segments dominate
- Whether the data looks realistic and usable
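Counts are easier to judge as shares. A minimal sketch that turns the channel counts printed above into proportions; the largest-to-smallest ratio is only a crude indicator of skew, not a standard threshold:

```python
import pandas as pd

# Channel counts as printed by value_counts() above
channel_counts = pd.Series({"Paid Ads": 4854, "Organic": 4690, "Referral": 2456})

# Share of customers per channel; on the raw column the same result comes from
# df["channel"].value_counts(normalize=True)
channel_share = channel_counts / channel_counts.sum()

# Crude skew indicator: size of the largest category relative to the smallest
imbalance_ratio = channel_counts.max() / channel_counts.min()

print(channel_share.round(3))
print(f"Largest / smallest: {imbalance_ratio:.2f}")
```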
Acquisition Month vs Gender
df["acquisition_month"] = df["acquisition_date"].dt.to_period("M").dt.to_timestamp()
gender_trend = (
df.groupby(["acquisition_month", "gender"], as_index=False)
.agg(count=("user_id", "count"))
)
fig = px.bar(
gender_trend,
x="acquisition_month",
y="count",
color="gender",
barmode="group",
title="Acquisition by Gender Over Time"
)
fig.show()
Acquisition Month vs Marital Status
marital_trend = (
df.groupby(["acquisition_month", "marital_status"], as_index=False)
.agg(count=("user_id", "count"))
)
fig = px.bar(
marital_trend,
x="acquisition_month",
y="count",
color="marital_status",
barmode="group",
title="Acquisition by Marital Status Over Time"
)
fig.show()
Acquisition Month vs Country
country_trend = (
df.groupby(["acquisition_month", "country"], as_index=False)
.agg(count=("user_id", "count"))
)
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
barmode="group",
title="Acquisition by Country Over Time"
)
fig.show()
The plot above is not informative at all: with twelve countries in every month, there are too many bars to compare.
How to fix it?
Based on the country distribution plot, customers are spread fairly evenly across countries (roughly 960 to 1,060 each), so no single country dominates. We can therefore create a mapper that groups countries into regions, which keeps the number of groups manageable.
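Before relying on a mapper, it helps to check it against the data: `Series.map` with a dictionary returns NaN for any value missing from the dictionary, and `groupby` drops NaN keys by default, so unmapped countries silently disappear from the aggregates. A sketch with a deliberately incomplete, hypothetical mapper (`partial_map`):

```python
import pandas as pd

# Hypothetical sample of countries and a deliberately incomplete mapper
countries = pd.Series(["France", "Austria", "Poland", "Switzerland", "Poland"])
partial_map = {"France": "Western Europe", "Poland": "Eastern Europe"}

regions = countries.map(partial_map)

# Values with no dictionary entry come back as NaN
unmapped = sorted(set(countries[regions.isna()]))
print(unmapped)  # ['Austria', 'Switzerland']
```

On the case-study data the equivalent check is `df.loc[df["country"].map(region_map).isna(), "country"].unique()`.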
Country-Region Mapper
region_map = {
"France": "Western Europe",
"Germany": "Western Europe",
"Netherlands": "Western Europe",
"Belgium": "Western Europe",
# Austria and Switzerland are in the data too; without these two entries
# their customers would get a NaN region and drop out of the grouped results
"Austria": "Western Europe",
"Switzerland": "Western Europe",
"Spain": "Southern Europe",
"Italy": "Southern Europe",
"Portugal": "Southern Europe",
"Poland": "Eastern Europe",
"Czech Republic": "Eastern Europe",
"Hungary": "Eastern Europe",
"Romania": "Eastern Europe",
"Sweden": "Northern Europe",
"Norway": "Northern Europe",
"Denmark": "Northern Europe",
"Finland": "Northern Europe"
}
Preparing the Data
# Step 1: Map each country to its region
df["region"] = df["country"].map(region_map)
# Step 2: Aggregate
country_trend = (
df.groupby(["acquisition_month", "region", "country"], as_index=False)
.agg(count=("user_id", "count"))
)
Visualizing the results | Subplots
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
# barmode = "group",
facet_col="region",
facet_col_wrap=2,
title="Acquisition by Country (Split by Region)"
)
fig.show()
Visualizing the results | Better Coloring
Plotly Express accepts a color_discrete_sequence argument, so we can pick whichever qualitative palette we like best. Popular options:
- px.colors.qualitative.Plotly (the default)
- px.colors.qualitative.Dark24 (better for many categories)
- px.colors.qualitative.Safe
- px.colors.qualitative.Set2
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
facet_col="region",
facet_col_wrap=2,
color_discrete_sequence=px.colors.qualitative.Set2,
title="Acquisition by Country | 1"
)
fig.show()
px.colors.qualitative.Plotly
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
facet_col="region",
facet_col_wrap=2,
color_discrete_sequence = px.colors.qualitative.Plotly,
title="Acquisition by Country | 2 "
)
fig.show()
px.colors.qualitative.Dark24
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
facet_col="region",
facet_col_wrap=2,
color_discrete_sequence = px.colors.qualitative.Dark24,
title="Acquisition by Country | 3"
)
fig.show()
px.colors.qualitative.Safe
fig = px.bar(
country_trend,
x="acquisition_month",
y="count",
color="country",
facet_col="region",
facet_col_wrap=2,
color_discrete_sequence=px.colors.qualitative.Safe,
title="Acquisition by Country (Styled Colors)"
)
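fig.show()
Custom Color Palette
You can also define your own colors. Since color_map below is a dictionary that maps each country to a color, pass it with color_discrete_map=color_map; color_discrete_sequence expects a list of colors instead.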
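color_map = {
"France": "#1f77b4",
"Germany": "#ff7f0e",
"Netherlands": "#2ca02c",
"Belgium": "#d62728",
"Spain": "#9467bd",
"Italy": "#8c564b",
"Portugal": "#e377c2",
"Poland": "#7f7f7f",
"Czech Republic": "#bcbd22",
"Sweden": "#17becf",
"Austria": "#aec7e8",
"Switzerland": "#ffbb78"
}
# Plot with the custom palette
fig = px.bar(country_trend, x="acquisition_month", y="count", color="country",
facet_col="region", facet_col_wrap=2, color_discrete_map=color_map,
title="Acquisition by Country (Custom Colors)")
fig.show()
Acquisition Month vs Channel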
channel_trend = (
df.groupby(["acquisition_month", "channel"], as_index=False)
.agg(count=("user_id", "count"))
)
fig = px.bar(
channel_trend,
x="acquisition_month",
y="count",
color="channel",
barmode="group",
title="Acquisition by Channel Over Time"
)
fig.show()
Acquisition Month vs Plan Type
plan_trend = (
df.groupby(["acquisition_month", "plan_type"], as_index=False)
.agg(count=("user_id", "count"))
)
fig = px.bar(
plan_trend,
x="acquisition_month",
y="count",
color="plan_type",
barmode="group",
title="Acquisition by Plan Type Over Time"
)
fig.show()
Try yourself
If composition differs across months, then:
- Differences in retention may be driven by who was acquired
- Not necessarily by when they were acquired
This is a critical insight for cohort analysis.
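Rather than judging composition by eye from grouped bars, one option is a chi-square test of independence on the month-by-segment contingency table, using the `scipy.stats` module imported at the top. The counts below are hypothetical; a small p-value would suggest the plan mix is not the same in every month, but it says nothing about why:

```python
import pandas as pd
from scipy import stats

# Hypothetical counts of new customers by acquisition month (rows) and plan type (columns);
# on the case-study data: pd.crosstab(df["acquisition_month"], df["plan_type"])
composition = pd.DataFrame(
    {"Basic": [120, 150, 90], "Standard": [200, 210, 220], "Premium": [80, 60, 130]},
    index=pd.Index(["2024-01", "2024-02", "2024-03"], name="acquisition_month"),
)

# Plan mix within each month (each row sums to 1)
shares = composition.div(composition.sum(axis=1), axis=0)

# Chi-square test of independence: does the plan mix depend on the acquisition month?
chi2, p_value, dof, expected = stats.chi2_contingency(composition)

print(shares.round(2))
print(f"chi2 = {chi2:.1f}, dof = {dof}, p = {p_value:.3g}")
```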
Exploratory Questions | Do some customer groups appear more likely to churn?
Here, our goal is to identify whether certain segments are associated with higher churn. First, we create a derived column that flags churn, and then we compare churn rates across:
- Gender
- Marital status
- Region
- Channel
- Plan type
This will also help us:
- Form hypotheses about retention drivers
- Understand whether cohort differences may be explained by customer composition
Defining Churn
# churn = 1 if customer has cancellation date, else 0
df["churn"] = df["cancellation_month"].notna().astype(int)
Compute Churn Rate by Segment
RECALL: Churn rate is defined as:
\[ \text{Churn Rate} = \frac{\text{Number of Churned Customers}}{\text{Total Customers}} \]
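For the whole dataset, the numerator is the number of customers with a non-missing `cancellation_month`. Using the totals shown earlier (12,000 rows, 3,581 missing cancellations), the overall rate works out as follows, which gives a baseline for the segment rates below:

```python
# Totals taken from the outputs above: 12,000 customers, 3,581 without a cancellation_month
total_customers = 12_000
active_customers = 3_581

churned_customers = total_customers - active_customers  # 8,419
overall_churn_rate = churned_customers / total_customers

print(f"Churned: {churned_customers}, churn rate: {overall_churn_rate:.1%}")  # 70.2%
```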
Churn Rate by Gender
gender_churn = (
df.groupby("gender")
.agg(
customers=("user_id", "count"),
churned=("churn", "sum")
)
.assign(churn_rate=lambda x: x["churned"] / x["customers"])
.reset_index()
)
gender_churn
| | gender | customers | churned | churn_rate |
|---|---|---|---|---|
| 0 | Female | 6016 | 4289 | 0.712932 |
| 1 | Male | 5984 | 4130 | 0.690174 |
Churn Rate by Marital Status
marital_churn = (
df.groupby("marital_status")
.agg(
customers=("user_id", "count"),
churned=("churn", "sum")
)
.assign(churn_rate=lambda x: x["churned"] / x["customers"])
.reset_index()
)
marital_churn
| | marital_status | customers | churned | churn_rate |
|---|---|---|---|---|
| 0 | Married | 4913 | 3467 | 0.705679 |
| 1 | Single | 7087 | 4952 | 0.698744 |
Churn Rate by Region
region_churn = (
df.groupby("region")
.agg(
customers=("user_id", "count"),
churned=("churn", "sum")
)
.assign(churn_rate=lambda x: x["churned"] / x["customers"])
.reset_index()
)
region_churn
| | region | customers | churned | churn_rate |
|---|---|---|---|---|
| 0 | Eastern Europe | 2009 | 1407 | 0.700348 |
| 1 | Northern Europe | 995 | 716 | 0.719598 |
| 2 | Southern Europe | 2968 | 2053 | 0.691712 |
| 3 | Western Europe | 4005 | 2807 | 0.700874 |
Churn Rate by Channel
channel_churn = (
df.groupby("channel")
.agg(
customers=("user_id", "count"),
churned=("churn", "sum")
)
.assign(churn_rate=lambda x: x["churned"] / x["customers"])
.reset_index()
)
channel_churn
| | channel | customers | churned | churn_rate |
|---|---|---|---|---|
| 0 | Organic | 4690 | 3286 | 0.700640 |
| 1 | Paid Ads | 4854 | 3397 | 0.699835 |
| 2 | Referral | 2456 | 1736 | 0.706840 |
Churn Rate by Plan Type
plan_churn = (
df.groupby("plan_type")
.agg(
customers=("user_id", "count"),
churned=("churn", "sum")
)
.assign(churn_rate=lambda x: x["churned"] / x["customers"])
.reset_index()
)
plan_churn
| | plan_type | customers | churned | churn_rate |
|---|---|---|---|---|
| 0 | Basic | 3222 | 2283 | 0.708566 |
| 1 | Premium | 2799 | 1953 | 0.697749 |
| 2 | Standard | 5979 | 4183 | 0.699615 |
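The five blocks above differ only in the grouping column, so they can be folded into one helper. `churn_rate_by` is a hypothetical name; the body is the same aggregation used above, shown here on a small made-up table:

```python
import pandas as pd


def churn_rate_by(data: pd.DataFrame, segment: str) -> pd.DataFrame:
    """Customers, churned customers and churn rate per value of `segment`.

    Assumes `data` has a `user_id` column and a 0/1 `churn` column, as built above.
    """
    return (
        data.groupby(segment)
        .agg(customers=("user_id", "count"), churned=("churn", "sum"))
        .assign(churn_rate=lambda x: x["churned"] / x["customers"])
        .reset_index()
    )


# Small hypothetical example
toy = pd.DataFrame({
    "user_id": [1, 2, 3, 4, 5, 6],
    "plan_type": ["Basic", "Basic", "Premium", "Premium", "Premium", "Standard"],
    "churn": [1, 0, 1, 1, 0, 0],
})
print(churn_rate_by(toy, "plan_type"))

# On the case-study data, one call per segment replaces the repeated blocks:
# tables = {col: churn_rate_by(df, col)
#           for col in ["gender", "marital_status", "region", "channel", "plan_type"]}
```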
Interpretation
At this stage, we are not yet making causal claims.
Instead, we are:
- Identifying patterns in churn across segments
- Detecting potentially risky customer groups
- Forming hypotheses (e.g., “Do certain plans lead to higher churn?”)
These insights will later connect to cohort analysis, where we evaluate whether differences in retention are driven by:
- acquisition timing
- or customer composition
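To judge whether a gap such as 71.3% vs 69.0% churn between women and men is more than noise, one common check is a two-proportion z-test. The sketch below uses the counts from the `gender_churn` table; with these counts z comes out at roughly 2.7, which shows that large samples make even a two-point gap statistically detectable. That is a statement about noise, not about cause or business importance:

```python
import math

from scipy import stats

# Churned / total customers by gender, from the gender_churn table above
churned_f, total_f = 4289, 6016
churned_m, total_m = 4130, 5984

p_f = churned_f / total_f
p_m = churned_m / total_m

# Pooled churn rate under the null hypothesis that both groups churn at the same rate
p_pool = (churned_f + churned_m) / (total_f + total_m)
se = math.sqrt(p_pool * (1 - p_pool) * (1 / total_f + 1 / total_m))

z = (p_f - p_m) / se
p_value = 2 * stats.norm.sf(abs(z))  # two-sided p-value

print(f"difference = {p_f - p_m:.3f}, z = {z:.2f}, p = {p_value:.4f}")
```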
Cohort Analysis
In this part, we will:
- Define a 12-month observation window
- Construct cohort indices
- Build retention tables
- Visualize cohort performance using heatmaps
This will allow us to answer the key question:
Do customers acquired at different times behave differently over their lifecycle?
Eventually we want to build something like this:

Step 1: Prepare Cohort Features
df["acquisition_month"] = df["acquisition_date"].dt.to_period("M")  # recomputed as a monthly Period (earlier it was a string, then a timestamp) so this section is self-contained
df["cancellation_month"] = pd.to_datetime(df["cancellation_month"], errors="coerce")
df["cancellation_month"] = df["cancellation_month"].dt.to_period("M")
# Calculate tenure (months until churn)
df["tenure"] = (
(df["cancellation_month"].dt.year - df["acquisition_month"].dt.year) * 12 +
(df["cancellation_month"].dt.month - df["acquisition_month"].dt.month)
)
Now we need to make some imputations:
- Fill missing cancellations (customers who have not churned) with the maximum observation window (12 months)
- Cap tenures above 12 months at 12 (the clip also sets any negative tenure to 0)
df["tenure"] = df["tenure"].fillna(12)  # not churned -> full 12-month window
df["tenure"] = df["tenure"].clip(lower=0, upper=12)  # keep tenure within 0-12 months
Step 2: Build Cohort Table
cohort_data = (
df.groupby(["acquisition_month", "tenure"])
.agg(users=("user_id", "count"))
.reset_index()
)
cohort_data.head()
| | acquisition_month | tenure | users |
|---|---|---|---|
| 0 | 2024-01 | 0 | 461 |
| 1 | 2024-01 | 4 | 4 |
| 2 | 2024-01 | 5 | 7 |
| 3 | 2024-01 | 6 | 37 |
| 4 | 2024-01 | 7 | 70 |
Step 3: Creating Cohort Size
# Users with tenure 0 in each acquisition month; used as the denominator in Step 5
cohort_size = (
cohort_data[cohort_data["tenure"] == 0]
.rename(columns={"users": "cohort_size"})
[["acquisition_month", "cohort_size"]]
)
cohort_size.head()
| | acquisition_month | cohort_size |
|---|---|---|
| 0 | 2024-01 | 461 |
| 10 | 2024-02 | 470 |
| 22 | 2024-03 | 415 |
| 33 | 2024-04 | 439 |
| 44 | 2024-05 | 456 |
Step 4: Merge Cohort Size Back
cohort_data = cohort_data.merge(
cohort_size,
on="acquisition_month"
)
Step 5: Calculating Retention Rate
cohort_data["retention_rate"] = (
cohort_data["users"] / cohort_data["cohort_size"]
)
Step 6: Pivoting
cohort_pivot = cohort_data.pivot(
index="acquisition_month",
columns="tenure",
values="retention_rate"
)
cohort_pivot.head()
| acquisition_month / tenure | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2024-01 | 1.0 | NaN | NaN | NaN | 0.008677 | 0.015184 | 0.080260 | 0.151844 | 0.271150 | 0.403471 | 0.433839 | 0.388286 | 0.505423 |
| 2024-02 | 1.0 | NaN | 0.004255 | 0.014894 | 0.068085 | 0.153191 | 0.270213 | 0.410638 | 0.438298 | 0.417021 | 0.278723 | 0.131915 | 0.063830 |
| 2024-03 | 1.0 | 0.585542 | 0.412048 | 0.448193 | 0.486747 | 0.318072 | 0.175904 | 0.086747 | 0.009639 | NaN | 0.002410 | 0.004819 | NaN |
| 2024-04 | 1.0 | 0.280182 | 0.280182 | 0.451025 | 0.464692 | 0.396355 | 0.291572 | 0.179954 | 0.063781 | 0.025057 | 0.004556 | NaN | NaN |
| 2024-05 | 1.0 | 0.032895 | 0.074561 | 0.162281 | 0.250000 | 0.425439 | 0.482456 | 0.445175 | 0.278509 | 0.142544 | 0.061404 | 0.017544 | 0.006579 |
Step 7: Visualizing
# Convert index to string for cleaner display
cohort_pivot_plot = cohort_pivot.copy()
cohort_pivot_plot.index = cohort_pivot_plot.index.astype(str)
cohort_pivot_plot = cohort_pivot_plot.sort_index(ascending=False)
fig = px.imshow(
cohort_pivot_plot,
aspect="auto",
text_auto=".0%",
color_continuous_scale="Blues"
)
fig.update_layout(
title="Cohort Retention Heatmap",
xaxis_title="Months Since Acquisition",
yaxis_title="Acquisition Month"
)
fig.show()
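Each cell of `cohort_pivot` is the number of customers whose tenure equals k divided by the number whose tenure is 0 (`cohort_size` in Step 3). It therefore describes when customers left rather than the share of the whole cohort still subscribed, which is also why the rows are not monotonically decreasing. If the goal is the classic retention curve from the introduction (100% in month 0, never increasing afterwards), one option is sketched below on hypothetical data. It assumes a customer counts as retained in month k when their tenure is at least k, and divides by the full cohort size:

```python
import pandas as pd

# Hypothetical customers: acquisition cohort and tenure in months, already imputed
# and capped at 12 as in Step 1 (12 = still active at the end of the window)
customers = pd.DataFrame({
    "acquisition_month": ["2024-01"] * 4 + ["2024-02"] * 3,
    "tenure": [1, 3, 12, 12, 2, 2, 12],
})

max_months = 12

# Assumption: retained in month k  <=>  tenure >= k (had not cancelled before month k).
# The mean of a boolean per cohort is the share of the full cohort that is retained.
retention = pd.DataFrame({
    k: customers["tenure"].ge(k).groupby(customers["acquisition_month"]).mean()
    for k in range(max_months + 1)
})

print(retention.round(2))
```

On the case-study data, the same dictionary comprehension can be applied to `df` using the `tenure` and `acquisition_month` columns from Step 1, and the result passed to `px.imshow` as in Step 7. Because customers without a cancellation were filled with 12 months, any cohort that has not been observed for a full year will look better than the data can actually confirm.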