from functools import partial

import numpy as np
from scipy.spatial.distance import pdist

from frouros.callbacks import PermutationTestDistanceBased
from frouros.detectors.data_drift import MMD
from frouros.utils import load, save
from frouros.utils.kernels import rbf_kernel

Save and load a detector

In this example, we will demonstrate how to save and load a detector, using the MMD detector together with the permutation test callback. We first fit the detector on a reference dataset and compare a test dataset against it. We then save the detector to a file, load it back, and repeat the comparison, asserting that the distance and p-value are identical before and after the round trip.

Set the random seed

We will set NumPy's global random seed so that the generated datasets, and therefore the results below, are reproducible.

seed = 31
np.random.seed(seed)
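A quick check of what seeding buys us: re-seeding NumPy's legacy global generator replays exactly the same draws. The sketch also shows `np.random.default_rng`, a local generator that leaves global state untouched; it is an alternative, not what this example uses.

```python
import numpy as np

# Re-seeding NumPy's legacy global generator replays the same sequence of draws.
np.random.seed(31)
first = np.random.multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]], size=5)

np.random.seed(31)
second = np.random.multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]], size=5)

# Same seed, same draws.
print(np.array_equal(first, second))  # True

# Alternative (not used in this example): a local Generator with its own state.
rng = np.random.default_rng(31)
local_draws = rng.multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]], size=5)
print(local_draws.shape)  # (5, 2)
```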

Generate data

We will generate two datasets of `num_samples = 100` points each. The reference dataset `X_ref` is drawn from a multivariate normal distribution with mean [0, 0] and covariance matrix [[1, 0], [0, 1]]. The test dataset `X_test` is drawn from a multivariate normal distribution with mean [1, 0] and covariance matrix [[1, 0], [0, 2]], so it differs from the reference both in the mean of the first feature and in the variance of the second.

num_samples = 100

x_mean = [0, 0]
x_cov = [
    [1, 0],
    [0, 1],
]

y_mean = [1, 0]
y_cov = [
    [1, 0],
    [0, 2],
]

X_ref = np.random.multivariate_normal(
    mean=x_mean,
    cov=x_cov,
    size=num_samples,
)
X_test = np.random.multivariate_normal(
    mean=y_mean,
    cov=y_cov,
    size=num_samples,
)
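To see the drift the detector should pick up, the sketch below draws much larger samples from the same two distributions and compares their empirical means and variances. It uses its own `np.random.default_rng`, so it does not disturb the seeded data above; the sample size of 50,000 is an arbitrary choice for stable estimates.

```python
import numpy as np

rng = np.random.default_rng(31)
# Large samples so the estimates sit close to the true parameters.
X_big = rng.multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]], size=50_000)
Y_big = rng.multivariate_normal(mean=[1, 0], cov=[[1, 0], [0, 2]], size=50_000)

# Expected to be roughly [1, 0]: the first feature is shifted.
mean_shift = Y_big.mean(axis=0) - X_big.mean(axis=0)
# Expected to be roughly [1, 2]: the second feature has doubled variance.
var_ratio = Y_big.var(axis=0) / X_big.var(axis=0)

print(np.round(mean_shift, 2), np.round(var_ratio, 2))
```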

Fit the detector

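We will set the RBF kernel bandwidth `sigma` with the median heuristic (the median pairwise Euclidean distance within the reference dataset), configure the MMD detector with a permutation test callback, and fit the detector on the reference dataset.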

sigma = np.median(
    pdist(
        X=X_ref,
        metric="euclidean",
    ),
)
sigma
1.5941478725484344
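The `sigma` above follows the median heuristic. A minimal sketch of that heuristic and of a Gaussian RBF kernel, assuming the common parameterization k(x, y) = exp(-||x - y||² / (2σ²)). `median_heuristic` and `gaussian_kernel` are hypothetical helpers, and frouros's `rbf_kernel` may use a different convention.

```python
import numpy as np
from scipy.spatial.distance import cdist, pdist


def median_heuristic(X):
    # Median of all pairwise Euclidean distances, as in the cell above.
    return np.median(pdist(X, metric="euclidean"))


def gaussian_kernel(X, Y, sigma):
    # Assumed parameterization: exp(-||x - y||^2 / (2 * sigma^2)).
    sq_dists = cdist(X, Y, metric="sqeuclidean")
    return np.exp(-sq_dists / (2.0 * sigma**2))


X = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
sigma = median_heuristic(X)  # pairwise distances are 5, 10, 5, so the median is 5.0
K = gaussian_kernel(X, X, sigma)
print(sigma, K.shape)  # 5.0 (3, 3)
```

Because the bandwidth scales with the typical distance in the data, the kernel values neither collapse to 0 nor saturate at 1 for most pairs.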
detector = MMD(
    kernel=partial(
        rbf_kernel,
        sigma=sigma,
    ),
    callbacks=PermutationTestDistanceBased(
        num_permutations=100,
        num_jobs=-1,
        method="exact",
        random_state=seed,
        name="permutation_test",
    ),
)

_ = detector.fit(
    X=X_ref,
)

Compare datasets before saving

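We will compare the test dataset against the reference dataset the detector was fitted on, recording the MMD distance and the permutation test p-value.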

distance, callback_logs = detector.compare(
    X=X_test,
)
before_save_distance = distance.distance
before_save_p_value = callback_logs["permutation_test"]["p_value"]
print(f"Distance: {before_save_distance:.8f}, p-value: {before_save_p_value:.8f}")
Distance: 0.14644993, p-value: 0.00990049
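For intuition about the two printed numbers, here is a minimal NumPy sketch of a squared-MMD statistic (the biased V-statistic estimator) with a permutation test around it. It is not frouros's implementation: `mmd2_biased` and `permutation_test` are hypothetical helpers, frouros may use a different MMD estimator, and the p-value uses the simple (count + 1) / (num_permutations + 1) estimate rather than whatever correction `method="exact"` applies.

```python
import numpy as np
from scipy.spatial.distance import cdist, pdist


def rbf(X, Y, sigma):
    # Gaussian RBF kernel matrix; parameterization assumed.
    return np.exp(-cdist(X, Y, metric="sqeuclidean") / (2.0 * sigma**2))


def mmd2_biased(X, Y, sigma):
    # Biased estimate of squared MMD: mean k(x, x') + mean k(y, y') - 2 mean k(x, y).
    return rbf(X, X, sigma).mean() + rbf(Y, Y, sigma).mean() - 2.0 * rbf(X, Y, sigma).mean()


def permutation_test(X, Y, sigma, num_permutations=100, seed=31):
    # Hypothetical helper: returns (observed statistic, permutation p-value).
    rng = np.random.default_rng(seed)
    observed = mmd2_biased(X, Y, sigma)
    pooled = np.vstack([X, Y])
    n = len(X)
    count = 0
    for _ in range(num_permutations):
        # Shuffle the pooled samples and split them into groups of the original sizes.
        idx = rng.permutation(len(pooled))
        stat = mmd2_biased(pooled[idx[:n]], pooled[idx[n:]], sigma)
        count += int(stat >= observed)
    # Add-one estimate: treats the observed split as one permutation, so p > 0.
    return observed, (count + 1) / (num_permutations + 1)


rng = np.random.default_rng(0)
X = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], size=100)
Y = rng.multivariate_normal([1, 0], [[1, 0], [0, 2]], size=100)
sigma = np.median(pdist(X, metric="euclidean"))  # median heuristic

stat, p_value = permutation_test(X, Y, sigma)
print(f"MMD^2: {stat:.6f}, p-value: {p_value:.6f}")
```

With 100 permutations, the smallest p-value this estimate can report is 1/101, which is why a strongly drifted test set lands near 0.0099.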

Save and load the detector

We will save the fitted detector to `detector.pkl` with `save` and load it back with `load`, replacing the in-memory `detector` with the restored one.

save(
    obj=detector,
    filename="detector.pkl",
)

detector = load(
    filename="detector.pkl",
)
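The `.pkl` extension suggests that `save` and `load` are built on Python's `pickle`. Assuming so, everything the detector holds must be picklable, which is one reason to pass `partial(rbf_kernel, sigma=sigma)` rather than a lambda. The sketch below uses SciPy's `cdist` as a stand-in kernel and shows that a `functools.partial` over an importable function survives a file round trip, while a lambda cannot be pickled.

```python
import os
import pickle
import tempfile
from functools import partial

import numpy as np
from scipy.spatial.distance import cdist

# A partial over an importable function is pickled by reference plus its arguments.
kernel = partial(cdist, metric="sqeuclidean")
state = {"kernel": kernel, "X_ref": np.arange(6.0).reshape(3, 2)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# The restored kernel gives identical output on the same data.
X = state["X_ref"]
same = np.array_equal(restored["kernel"](X, X), kernel(X, X))

# A lambda has no importable name, so pickle refuses it.
try:
    pickle.dumps(lambda A, B: cdist(A, B, metric="sqeuclidean"))
    lambda_pickled = True
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_pickled = False

print(same, lambda_pickled)  # True False
```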

Compare datasets after loading

Using the loaded detector, we will compare the test dataset against the reference dataset again.

distance, callback_logs = detector.compare(
    X=X_test,
)
after_save_distance = distance.distance
after_save_p_value = callback_logs["permutation_test"]["p_value"]
print(f"Distance: {after_save_distance:.8f}, p-value: {after_save_p_value:.8f}")
Distance: 0.14644993, p-value: 0.00990049

Finally, we assert that the distance and p-value obtained with the loaded detector are identical to those obtained before saving.

assert before_save_distance == after_save_distance
assert before_save_p_value == after_save_p_value
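The p-value matches after loading because the permutation test is seeded (`random_state=seed`). One mechanism that yields this property, sketched below under the assumption that the random state is either re-created from the seed or pickled with the detector: NumPy Generators built from the same seed produce the same permutations, and a pickled Generator resumes exactly where the original left off. Exact equality (`==`) is reasonable here because the same computation runs twice in the same environment; across machines or library versions, a tolerance-based check such as `np.isclose` is the safer choice.

```python
import pickle

import numpy as np

# Two Generators from the same seed yield identical permutations.
fresh_a = np.random.default_rng(31).permutation(200)
fresh_b = np.random.default_rng(31).permutation(200)

# A pickled Generator carries its internal state with it.
rng = np.random.default_rng(31)
_ = rng.permutation(200)  # advance the state, as earlier work would
restored = pickle.loads(pickle.dumps(rng))

next_original = rng.permutation(200)
next_restored = restored.permutation(200)

print(np.array_equal(fresh_a, fresh_b), np.array_equal(next_original, next_restored))  # True True
```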