Matching
0. Import libraries
[1]:
from lightautoml.addons.hypex import Matcher
1. Create or upload your dataset
[2]:
from lightautoml.addons.hypex.utils.tutorial_data_creation import create_test_data
[3]:
df = create_test_data(num_users=10000, rs=42, na_step=45, nan_cols=['age', 'gender'])
df
[3]:
| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 504.5 | 422.777778 | NaN | F | Logistics |
| 1 | 1 | 4 | 1 | 500.0 | 506.333333 | 51.0 | NaN | E-commerce |
| 2 | 2 | 0 | 0 | 485.0 | 434.000000 | 56.0 | F | Logistics |
| 3 | 3 | 8 | 1 | 452.0 | 468.111111 | 46.0 | M | E-commerce |
| 4 | 4 | 0 | 0 | 488.5 | 420.111111 | 56.0 | M | Logistics |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 2 | 1 | 482.0 | 501.666667 | 31.0 | M | Logistics |
| 9996 | 9996 | 0 | 0 | 453.0 | 406.888889 | 53.0 | M | Logistics |
| 9997 | 9997 | 0 | 0 | 461.0 | 415.111111 | 52.0 | F | E-commerce |
| 9998 | 9998 | 10 | 1 | 491.5 | 439.222222 | 22.0 | M | E-commerce |
| 9999 | 9999 | 2 | 1 | 481.0 | 517.222222 | 53.0 | M | E-commerce |
10000 rows × 8 columns
[4]:
df.columns
[4]:
Index(['user_id', 'signup_month', 'treat', 'pre_spends', 'post_spends', 'age',
'gender', 'industry'],
dtype='object')
[5]:
df['treat'].value_counts()
[5]:
treat
0 5002
1 4998
Name: count, dtype: int64
[6]:
df['gender'].isna().sum()
[6]:
223
2. Matching
2.0 Init params
[7]:
info_col = ['user_id']
outcome = 'post_spends'
treatment = 'treat'
2.1 Simple matching
[8]:
# Standard model with base parameters
model = Matcher(input_data=df, outcome=outcome, treatment=treatment, info_col=info_col,
algo='fast')
[18.12.2024 22:05:15 | hypex | INFO]: Number of NaN values filled with zeros: 446
Feature selection estimates how much each feature contributes to approximating the target. It cannot detect features that are missing from the data, capture complex feature effects on the target, or judge whether a feature matters from a business logic perspective. The algorithm will not work correctly if the data contains leaks. Points to consider when selecting features:
- Data leaks: there must be none.
- Influence on treatment distribution: features should not affect the treatment distribution.
- The target should be describable by the features.
- All features that significantly affect the target should be included.
- The business rationale of the features.
The feature selection function helps with these tasks, but it does not solve them. It neither removes the user's responsibility for the choice of features nor justifies that choice.
[9]:
selected_features = model.feature_select()
selected_features
/Users/tikhomirov/PycharmProjects/Sber_New/LightAutoML/.venv/lib/python3.10/site-packages/hypex/selectors/feature_selector.py:42: UserWarning: FeatureSelector does not rule out the possibility of overlooked features, the complex impact of features on target description, or the significance of features from a business logic perspective.
warnings.warn(
[9]:
| | rank |
|---|---|
| signup_month | 1 |
| pre_spends | 2 |
| age | 3 |
| gender_F | 4 |
| gender_M | 5 |
| industry_Logistics | 6 |
[10]:
chosen_features = selected_features[:4].index
chosen_features
[10]:
Index(['signup_month', 'pre_spends', 'age', 'gender_F'], dtype='object')
[11]:
results, quality_results, df_matched = model.estimate(features=chosen_features)
[18.12.2024 22:05:18 | Faiss hypex | INFO]: The entry of bias into the ATT is 0.1%
What should you check in this table?
- effect_size: the estimated effect. ATT is the average effect in the treated group, ATC is the average effect in the control group, and ATE is the average effect over the whole sample.
- std_err (standard error): how precisely the estimate approximates the true value in the total population.
- p-val (p-value): the probability of observing an effect at least this large if the true effect were zero. A displayed value of 0.0 is a very small probability rounded down; it does not mean that chance is ruled out completely.
- ci_lower, ci_upper (confidence interval): the interval that covers the estimated parameter (ATT, ATC, ATE) with a given probability.
[12]:
results
[12]:
| | effect_size | std_err | p-val | ci_lower | ci_upper | outcome |
|---|---|---|---|---|---|---|
| ATE | 82.636764 | 2.339283 | 0.0 | 78.051770 | 87.221757 | post_spends |
| ATC | 101.687566 | 4.571439 | 0.0 | 92.727546 | 110.647585 | post_spends |
| ATT | 63.570714 | 0.694726 | 0.0 | 62.209051 | 64.932378 | post_spends |
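The confidence bounds in the results table above can be cross-checked by hand. The sketch below takes the ATT row and assumes a two-sided 95% interval built from a normal approximation with the rounded quantile 1.96. That assumption reproduces the printed bounds to about five decimal places, but the source does not state how the library computes them.

```python
from statistics import NormalDist

# Values copied from the ATT row of the results table.
effect_size = 63.570714
std_err = 0.694726

# Assumption: two-sided 95% interval using the rounded normal quantile 1.96.
z_crit = 1.96
ci_lower = effect_size - z_crit * std_err
ci_upper = effect_size + z_crit * std_err

# Two-sided p-value for H0 "the true effect is zero" under a normal approximation.
z_stat = effect_size / std_err
p_value = 2 * (1 - NormalDist().cdf(abs(z_stat)))

print(f"95% CI: [{ci_lower:.6f}, {ci_upper:.6f}]")
print(f"z = {z_stat:.1f}, p-value = {p_value:.3g}")
```

With a z statistic this large, the p-value underflows to zero in floating point, which is why the table shows 0.0.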
The quality_results variable contains:
- results of the PSI (population stability index) test
- results of the Kolmogorov-Smirnov test
- results of the SMD (standardized mean difference) check
- number of repeats
PSI rules:
- PSI < 0.1: no change. You can continue using the existing model.
- 0.1 <= PSI < 0.2: slight change is required.
- PSI >= 0.2: significant change is required. Ideally, you should not use this model any more.
SMD rules:
- SMD < 0.1: for a randomized trial, the SMD of every covariate should typically fall into this bucket.
- 0.1 <= SMD <= 0.2: not necessarily balanced, but small enough that people are usually not too worried.
- SMD > 0.2: values above this threshold are considered seriously imbalanced.
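The SMD buckets listed above can be applied to any single covariate. The following is a minimal sketch using the common textbook definition of SMD (absolute mean difference divided by the pooled standard deviation). The helper names `smd` and `smd_bucket` and the sample ages are hypothetical, and the formula is not necessarily the one HypEx uses internally.

```python
from statistics import mean, variance

def smd(treated, control):
    """Absolute standardized mean difference with a pooled standard deviation."""
    diff = abs(mean(treated) - mean(control))
    pooled_sd = ((variance(treated) + variance(control)) / 2) ** 0.5
    if pooled_sd == 0:
        # Both samples are constant: balanced only if the constants coincide.
        return 0.0 if diff == 0 else float("inf")
    return diff / pooled_sd

def smd_bucket(value):
    """Map an SMD value to the three buckets described in the rules."""
    if value < 0.1:
        return "balanced"
    if value <= 0.2:
        return "minor imbalance"
    return "serious imbalance"

# Hypothetical ages of treated units and their matched controls.
treated_age = [31, 45, 52, 38, 60, 27]
control_age = [33, 44, 50, 41, 58, 29]

value = smd(treated_age, control_age)
print(f"SMD = {value:.3f} -> {smd_bucket(value)}")
```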
[13]:
quality_results.keys()
[13]:
dict_keys(['psi', 'ks_test', 'smd', 'repeats'])
Kolmogorov-Smirnov test: the distribution of one sample is compared with the distribution of a second sample to decide whether the two samples come from the same distribution.
The table shows the p-values of the test. If a p-value is below 0.05, we reject the null hypothesis: there is enough evidence that the two samples do not have the same distribution.
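To make the test concrete, here is a small standard-library sketch of the two-sample Kolmogorov-Smirnov statistic and an asymptotic p-value. The function names and the sample data are hypothetical, and the p-value uses a series approximation, so it will not match an exact implementation digit for digit. It is an illustration of the idea, not the code HypEx runs.

```python
import math
from bisect import bisect_right

def ks_statistic(a, b):
    """Largest vertical gap between the empirical CDFs of samples a and b."""
    a, b = sorted(a), sorted(b)
    gap = 0.0
    for v in sorted(set(a) | set(b)):
        cdf_a = bisect_right(a, v) / len(a)
        cdf_b = bisect_right(b, v) / len(b)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap

def ks_pvalue(d, n, m, terms=100):
    """Asymptotic two-sided p-value from the Kolmogorov series (approximation)."""
    en = math.sqrt(n * m / (n + m))
    lam = (en + 0.12 + 0.11 / en) * d
    if lam < 0.3:
        # The series converges slowly here and the p-value is practically 1.
        return 1.0
    series = sum((-1) ** (k - 1) * math.exp(-2 * (k * lam) ** 2)
                 for k in range(1, terms + 1))
    return min(1.0, max(0.0, 2 * series))

# Hypothetical feature values: the control sample is shifted by 10 units.
treated = [20 + i % 40 for i in range(200)]
control = [30 + i % 40 for i in range(200)]

d = ks_statistic(treated, control)
p = ks_pvalue(d, len(treated), len(control))
print(f"D = {d:.3f}, p-value = {p:.3g}, reject H0 at 0.05: {p < 0.05}")
```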
[14]:
quality_results['ks_test']
[14]:
| | match_control_to_treat | match_treat_to_control |
|---|---|---|
| age | 1.418144e-01 | 2.209790e-07 |
| pre_spends | 2.348715e-264 | 3.829212e-19 |
| signup_month | 0.000000e+00 | 0.000000e+00 |
[15]:
df_matched
[15]:
| | index | signup_month | pre_spends | age | gender_F | gender_M | industry_Logistics | signup_month_matched | pre_spends_matched | age_matched | gender_F_matched | gender_M_matched | industry_Logistics_matched | index_matched | post_spends | post_spends_matched | post_spends_matched_bias | treat | treat_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 956 | 8 | 487.0 | 22.0 | 0 | 1 | 0 | 0.0 | 506.0 | 23.0 | 0.0 | 1.0 | 1.0 | [6352] | 462.222222 | 408.333333 | 54.004165 | 1 | 0 |
| 1 | 7966 | 5 | 471.5 | 69.0 | 0 | 0 | 1 | 0.0 | 483.5 | 69.0 | 0.0 | 1.0 | 1.0 | [5349] | 505.222222 | 415.222222 | 90.068667 | 1 | 0 |
| 2 | 7231 | 4 | 487.0 | 62.0 | 1 | 0 | 1 | 0.0 | 496.5 | 62.0 | 1.0 | 0.0 | 1.0 | [8654] | 503.555556 | 428.777778 | 74.832139 | 1 | 0 |
| 3 | 1443 | 1 | 517.0 | 36.0 | 1 | 0 | 1 | 0.0 | 520.5 | 36.0 | 1.0 | 0.0 | 1.0 | [7881] | 526.111111 | 423.111111 | 103.020028 | 1 | 0 |
| 4 | 7973 | 10 | 501.0 | 65.0 | 1 | 0 | 0 | 0.0 | 525.0 | 65.0 | 1.0 | 0.0 | 0.0 | [7333] | 440.444444 | 416.666667 | 23.915112 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4997 | 2926 | 0 | 486.0 | 45.0 | 0 | 0 | 0 | 2.0 | 479.5 | 46.0 | 0.0 | 1.0 | 1.0 | [971] | 421.777778 | 509.888889 | 89.948512 | 0 | 1 |
| 4998 | 1656 | 0 | 465.5 | 39.0 | 1 | 0 | 1 | 2.0 | 460.0 | 40.0 | 1.0 | 0.0 | 1.0 | [6730] | 417.666667 | 502.222222 | 86.443231 | 0 | 1 |
| 4999 | 2812 | 0 | 451.0 | 53.0 | 1 | 0 | 1 | 2.0 | 441.5 | 57.0 | 1.0 | 0.0 | 1.0 | [6610] | 417.333333 | 521.555556 | 106.002291 | 0 | 1 |
| 5000 | 2813 | 0 | 504.5 | 69.0 | 0 | 1 | 0 | 1.0 | 514.0 | 65.0 | 0.0 | 1.0 | 0.0 | [9319] | 422.888889 | 528.555556 | 107.086127 | 0 | 1 |
| 5001 | 1297 | 0 | 495.5 | 43.0 | 1 | 0 | 1 | 1.0 | 496.5 | 38.0 | 1.0 | 0.0 | 0.0 | [3004] | 415.111111 | 532.555556 | 118.405412 | 0 | 1 |
10000 rows × 19 columns
[16]:
df_matched[df_matched['industry_Logistics'] != df_matched['industry_Logistics_matched']]
[16]:
| | index | signup_month | pre_spends | age | gender_F | gender_M | industry_Logistics | signup_month_matched | pre_spends_matched | age_matched | gender_F_matched | gender_M_matched | industry_Logistics_matched | index_matched | post_spends | post_spends_matched | post_spends_matched_bias | treat | treat_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 956 | 8 | 487.0 | 22.0 | 0 | 1 | 0 | 0.0 | 506.0 | 23.0 | 0.0 | 1.0 | 1.0 | [6352] | 462.222222 | 408.333333 | 54.004165 | 1 | 0 |
| 5 | 8108 | 7 | 482.0 | 54.0 | 0 | 1 | 1 | 0.0 | 499.0 | 55.0 | 0.0 | 1.0 | 0.0 | [1742] | 478.333333 | 411.222222 | 67.214942 | 1 | 0 |
| 6 | 7228 | 7 | 493.0 | 35.0 | 1 | 0 | 1 | 0.0 | 511.0 | 36.0 | 1.0 | 0.0 | 0.0 | [5653] | 492.555556 | 419.444444 | 73.220665 | 1 | 0 |
| 7 | 7968 | 10 | 512.5 | 64.0 | 1 | 0 | 0 | 0.0 | 540.0 | 63.0 | 1.0 | 0.0 | 1.0 | [4470] | 436.777778 | 418.666667 | 18.261920 | 1 | 0 |
| 8 | 1396 | 6 | 479.0 | 27.0 | 0 | 0 | 0 | 0.0 | 493.5 | 27.0 | 0.0 | 1.0 | 1.0 | [6721] | 485.666667 | 408.888889 | 76.860750 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4990 | 2924 | 0 | 493.0 | 32.0 | 1 | 0 | 1 | 1.0 | 496.5 | 38.0 | 1.0 | 0.0 | 0.0 | [3004] | 425.000000 | 532.555556 | 108.985004 | 0 | 1 |
| 4991 | 313 | 0 | 484.0 | 35.0 | 1 | 0 | 0 | 2.0 | 480.5 | 34.0 | 1.0 | 0.0 | 1.0 | [7816] | 426.000000 | 510.777778 | 86.703675 | 0 | 1 |
| 4992 | 2919 | 0 | 490.5 | 20.0 | 0 | 1 | 1 | 2.0 | 483.5 | 18.0 | 0.0 | 1.0 | 0.0 | [2781] | 420.333333 | 517.555556 | 98.940997 | 0 | 1 |
| 4997 | 2926 | 0 | 486.0 | 45.0 | 0 | 0 | 0 | 2.0 | 479.5 | 46.0 | 0.0 | 1.0 | 1.0 | [971] | 421.777778 | 509.888889 | 89.948512 | 0 | 1 |
| 5001 | 1297 | 0 | 495.5 | 43.0 | 1 | 0 | 1 | 1.0 | 496.5 | 38.0 | 1.0 | 0.0 | 0.0 | [3004] | 415.111111 | 532.555556 | 118.405412 | 0 | 1 |
4964 rows × 19 columns
2.2 Matching with a fixed variable
[17]:
group_col = "industry"
group_col can also be a list of column names.
[18]:
model = Matcher(input_data=df, outcome=outcome, treatment=treatment,
info_col=info_col, group_col=group_col)
[18.12.2024 22:05:18 | hypex | INFO]: Number of NaN values filled with zeros: 446
[19]:
selected_features = model.feature_select()
selected_features
/Users/tikhomirov/PycharmProjects/Sber_New/LightAutoML/.venv/lib/python3.10/site-packages/hypex/selectors/feature_selector.py:42: UserWarning: FeatureSelector does not rule out the possibility of overlooked features, the complex impact of features on target description, or the significance of features from a business logic perspective.
warnings.warn(
[19]:
| | rank |
|---|---|
| signup_month | 1 |
| pre_spends | 2 |
| age | 3 |
| gender_F | 4 |
| gender_M | 5 |
[20]:
chosen_features = selected_features[:4].index
chosen_features
[20]:
Index(['signup_month', 'pre_spends', 'age', 'gender_F'], dtype='object')
[21]:
results, quality_results, df_matched = model.estimate(features=chosen_features)
[18.12.2024 22:05:22 | Faiss hypex | INFO]: The entry of bias into the ATT is 0.2%
[22]:
results
[22]:
| | effect_size | std_err | p-val | ci_lower | ci_upper | outcome |
|---|---|---|---|---|---|---|
| ATE | 81.192392 | 1.631028 | 0.0 | 77.995577 | 84.389208 | post_spends |
| ATC | 98.945415 | 3.116605 | 0.0 | 92.836869 | 105.053961 | post_spends |
| ATT | 63.425162 | 0.652349 | 0.0 | 62.146559 | 64.703765 | post_spends |
[23]:
df_matched[df_matched['industry'] != df_matched['industry_matched']]
[23]:
| | index | signup_month | pre_spends | age | gender_F | gender_M | industry | signup_month_matched | pre_spends_matched | age_matched | gender_F_matched | gender_M_matched | industry_matched | index_matched | post_spends | post_spends_matched | post_spends_matched_bias | treat | treat_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
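The empty result above means that no matched pair differs in industry. The sketch below shows that idea in isolation: nearest-neighbor search restricted to units of the same group. Everything in it is an illustrative assumption (the `match_within_groups` helper, the toy records, and the plain unscaled Euclidean distance). HypEx itself uses a Faiss-based search whose details are not shown in this tutorial.

```python
import math

def match_within_groups(units, features, group_key):
    """For each treated unit, find the nearest control unit in the same group.

    Returns a dict {treated_position: control_position}.
    """
    matches = {}
    for i, unit in enumerate(units):
        if unit["treat"] != 1:
            continue
        point = [unit[f] for f in features]
        # Only controls that share the group value are candidates.
        candidates = [
            (math.dist(point, [other[f] for f in features]), j)
            for j, other in enumerate(units)
            if other["treat"] == 0 and other[group_key] == unit[group_key]
        ]
        if candidates:
            matches[i] = min(candidates)[1]
    return matches

# Hypothetical records; unit 1 is the closest control to unit 0 overall,
# but it belongs to a different industry and is therefore never chosen.
units = [
    {"treat": 1, "industry": "Logistics", "pre_spends": 487.0, "age": 22.0},
    {"treat": 0, "industry": "E-commerce", "pre_spends": 487.5, "age": 22.0},
    {"treat": 0, "industry": "Logistics", "pre_spends": 506.0, "age": 23.0},
    {"treat": 1, "industry": "E-commerce", "pre_spends": 471.5, "age": 69.0},
    {"treat": 0, "industry": "E-commerce", "pre_spends": 483.5, "age": 69.0},
]

matches = match_within_groups(units, ["pre_spends", "age"], "industry")
for treated_pos, control_pos in matches.items():
    print(treated_pos, "->", control_pos, units[control_pos]["industry"])
```

Restricting candidates this way trades some closeness on the numeric features for an exact match on the grouping column.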
3. Results
3.1 ATE, ATT, ATC
[24]:
results
[24]:
| | effect_size | std_err | p-val | ci_lower | ci_upper | outcome |
|---|---|---|---|---|---|---|
| ATE | 81.192392 | 1.631028 | 0.0 | 77.995577 | 84.389208 | post_spends |
| ATC | 98.945415 | 3.116605 | 0.0 | 92.836869 | 105.053961 | post_spends |
| ATT | 63.425162 | 0.652349 | 0.0 | 62.146559 | 64.703765 | post_spends |
3.2 SMD, PSI, KS-test, repeats
[25]:
quality_results.keys()
[25]:
dict_keys(['psi', 'ks_test', 'smd', 'repeats'])
[26]:
quality_results['psi']
[26]:
| | column_treated | anomaly_score_treated | check_result_treated | column_untreated | anomaly_score_untreated | check_result_untreated |
|---|---|---|---|---|---|---|
| 0 | age_treated | 0.01 | OK | age_untreated | 0.08 | OK |
| 1 | gender_F_treated | 0.00 | OK | gender_F_untreated | 0.00 | OK |
| 2 | industry_treated | 0.00 | OK | industry_untreated | 0.00 | OK |
| 3 | pre_spends_treated | 0.61 | NOK | pre_spends_untreated | 0.16 | OK |
| 4 | signup_month_treated | 16.14 | NOK | signup_month_untreated | 0.00 | OK |
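For intuition about the anomaly scores in the PSI table above, the following sketch computes a population stability index between two samples with the usual formula, the sum over bins of (actual share - expected share) * ln(actual share / expected share). The equal-width binning, the bin count, the epsilon floor for empty bins, and the `psi` helper name are all assumptions made for illustration. The tutorial does not document how HypEx bins the data, so the numbers will not match the table.

```python
import math

def psi(expected, actual, bins=10, eps=1e-6):
    """Population stability index with equal-width bins taken from `expected`."""
    lo, hi = min(expected), max(expected)
    width = (hi - lo) / bins or 1.0  # avoid zero width for constant samples

    def shares(sample):
        counts = [0] * bins
        for x in sample:
            # Values outside the expected range are clamped to the edge bins.
            i = min(max(int((x - lo) / width), 0), bins - 1)
            counts[i] += 1
        # Floor the shares so that empty bins do not produce log(0).
        return [max(c / len(sample), eps) for c in counts]

    e, a = shares(expected), shares(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))

# Hypothetical samples: an identical copy and a strongly shifted copy.
baseline = list(range(100))
shifted = [x + 50 for x in baseline]

print(f"PSI(identical) = {psi(baseline, baseline):.2f}")
print(f"PSI(shifted)   = {psi(baseline, shifted):.2f}")
```

A score of 0.2 or more corresponds to the "significant change" rule, which is consistent with the NOK flags shown for scores such as 0.61.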
[27]:
quality_results['ks_test']
[27]:
| | match_control_to_treat | match_treat_to_control |
|---|---|---|
| age | 1.292517e-01 | 2.721265e-04 |
| pre_spends | 4.164710e-263 | 1.725086e-24 |
| signup_month | 0.000000e+00 | 0.000000e+00 |
[28]:
quality_results['repeats']
[28]:
{'match_control_to_treat': 0.42, 'match_treat_to_control': 0.07}
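The tutorial does not define exactly how the repeats value is calculated, so the sketch below only illustrates two closely related quantities that such a metric could report: the share of distinct matched indexes and its complement, the share of reused ones. The helper name `index_shares` and the toy `index_matched` lists are hypothetical. Check the HypEx source before interpreting the numbers above in either sense.

```python
def index_shares(index_matched):
    """Return (unique_share, reused_share) for a column of matched-index lists."""
    flat = [idx for group in index_matched for idx in group]
    if not flat:
        return 0.0, 0.0
    unique_share = len(set(flat)) / len(flat)
    return unique_share, 1 - unique_share

# Hypothetical matched indexes, shaped like the index_matched column.
index_matched = [[6352], [5349], [6352], [3004], [3004]]

unique_share, reused_share = index_shares(index_matched)
print(f"unique: {unique_share:.2f}, reused: {reused_share:.2f}")
```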
4. Save model
[29]:
model.save("test_model.pickle")
[30]:
model2 = Matcher.load("test_model.pickle")
[31]:
model2.results
[31]:
| | effect_size | std_err | p-val | ci_lower | ci_upper | outcome |
|---|---|---|---|---|---|---|
| ATE | 81.192392 | 1.631028 | 0.0 | 77.995577 | 84.389208 | post_spends |
| ATC | 98.945415 | 3.116605 | 0.0 | 92.836869 | 105.053961 | post_spends |
| ATT | 63.425162 | 0.652349 | 0.0 | 62.146559 | 64.703765 | post_spends |
[32]:
model.results
[32]:
| | effect_size | std_err | p-val | ci_lower | ci_upper | outcome |
|---|---|---|---|---|---|---|
| ATE | 81.192392 | 1.631028 | 0.0 | 77.995577 | 84.389208 | post_spends |
| ATC | 98.945415 | 3.116605 | 0.0 | 92.836869 | 105.053961 | post_spends |
| ATT | 63.425162 | 0.652349 | 0.0 | 62.146559 | 64.703765 | post_spends |