How to Use Dependency Discovery and Xi Pruning

By default, JPTs assume that every feature can influence every target variable. In high-dimensional datasets this leads to unnecessarily large trees, because the learning algorithm may chase spurious relationships between unrelated variables. pyjpt provides two mechanisms to address this: dependency discovery, which identifies which feature–target pairs are genuinely related before tree construction, and xi-based pruning, which stops splitting when no statistically significant dependence remains in a partition.

Both mechanisms are built on Chatterjee’s \(\xi\) correlation coefficient [Cha21], a rank-based measure of functional dependence whose properties, listed below, make it well suited to both tasks.

Mathematical Background

Chatterjee’s \(\xi\) coefficient

Given \(n\) paired observations \((X_1, Y_1), \ldots, (X_n, Y_n)\), sort the pairs by \(X\) so that \(X_{(1)} \leq \cdots \leq X_{(n)}\), let \(Y_{(i)}\) be the \(Y\) value paired with \(X_{(i)}\), and let \(r_{(i)}\) denote the rank of \(Y_{(i)}\) among all \(Y\) values. Assuming no ties among the \(Y\) values,

\[\xi_n(X, Y) = 1 - \frac{3 \sum_{i=1}^{n-1} |r_{(i+1)} - r_{(i)}|}{n^2 - 1}.\]

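Intuition: If \(Y\) is a strictly monotone function of \(X\), sorting by \(X\) also sorts \(Y\) (in ascending or descending order), so consecutive ranks differ by exactly 1 and \(\xi_n = 1 - 3/(n+1) \to 1\). More generally, whenever \(Y\) is a measurable function of \(X\), the average rank jump becomes negligible relative to \(n\) as \(n\) grows, so \(\xi_n \to 1\). If \(X\) and \(Y\) are independent, the ranks form a random permutation: the expected absolute difference between two consecutive ranks is \((n+1)/3\), so the sum has expectation \((n^2-1)/3\), \(\mathbb{E}[\xi_n] = 0\), and \(\xi_n \to 0\).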
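A direct transcription of the formula makes the intuition concrete. The sketch below is pure Python, assumes no ties among the \(Y\) values, and uses a hypothetical helper xi_n (it is not pyjpt's xi_correlation). For a strictly decreasing relationship every rank jump is 1, so the sum is \(n-1\) and \(\xi_n = 1 - 3(n-1)/(n^2-1) = 1 - 3/(n+1)\); for independent data the value lands near 0.

```python
import random


def xi_n(x, y):
    """Chatterjee's xi_n(X, Y) (hypothetical helper); assumes no ties in y."""
    n = len(x)
    # rank[idx] = rank of y[idx] among all y values (1 = smallest).
    rank = [0] * n
    for r, idx in enumerate(sorted(range(n), key=lambda k: y[k]), start=1):
        rank[idx] = r
    # Reorder the ranks by increasing x: r_(1), ..., r_(n).
    r_sorted = [rank[k] for k in sorted(range(n), key=lambda k: x[k])]
    # Sum of absolute jumps between consecutive ranks.
    jumps = sum(abs(r_sorted[k + 1] - r_sorted[k]) for k in range(n - 1))
    return 1.0 - 3.0 * jumps / (n * n - 1)


random.seed(0)
n = 1000
x = [random.random() for _ in range(n)]

# Strictly decreasing relationship: every jump is 1, matches 1 - 3/(n + 1).
print(xi_n(x, [-v for v in x]), 1 - 3 / (n + 1))

# Independent y: xi_n is centred at 0.
print(xi_n(x, [random.random() for _ in range(n)]))
```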

Key properties:

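  • The limit \(\xi(X, Y) = \lim_{n \to \infty} \xi_n(X, Y)\), which exists whenever \(Y\) is not constant, satisfies \(\xi = 0\) if and only if \(X\) and \(Y\) are independent.

  • \(\xi = 1\) if and only if \(Y\) is almost surely a measurable function of \(X\).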

  • Asymmetric: \(\xi(X, Y) \neq \xi(Y, X)\) in general. This is a feature: it measures how well \(Y\) can be predicted from \(X\), which is exactly the question a split on \(X\) answers.

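  • Distribution-free: under independence (with \(Y\) continuous), the distribution of \(\xi_n\) does not depend on the distributions of \(X\) or \(Y\), so the significance test below needs no distributional assumptions.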

  • Detects any dependence: unlike Pearson (linear) or Spearman (monotonic), \(\xi\) detects arbitrary functional relationships, including periodic or many-to-one mappings.

  • Computable in \(\mathcal{O}(n \log n)\) time, since the cost is dominated by sorting.
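The asymmetry and the detection of many-to-one relationships can be seen on \(Y = X^2\) with \(X\) uniform on \([-1, 1]\): \(Y\) is a function of \(X\), but \(X = \pm\sqrt{Y}\) is not a function of \(Y\), and the Pearson correlation is close to 0 because the relationship is symmetric. This self-contained sketch repeats the hypothetical no-ties helper xi_n (not pyjpt's implementation) and a plain Pearson helper.

```python
import random


def xi_n(x, y):
    """Chatterjee's xi_n(X, Y) (hypothetical helper); assumes no ties in y."""
    n = len(x)
    by_y = sorted(range(n), key=lambda k: y[k])
    rank = {idx: r for r, idx in enumerate(by_y, start=1)}
    r_sorted = [rank[k] for k in sorted(range(n), key=lambda k: x[k])]
    jumps = sum(abs(b - a) for a, b in zip(r_sorted, r_sorted[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


def pearson(x, y):
    """Sample Pearson correlation coefficient."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / (sxx * syy) ** 0.5


random.seed(1)
x = [random.uniform(-1, 1) for _ in range(2000)]
y = [v * v for v in x]  # many-to-one: x and -x map to the same y

print(f"pearson(X, Y) = {pearson(x, y):+.3f}")  # close to 0
print(f"xi(X, Y)      = {xi_n(x, y):.3f}")      # close to 1
print(f"xi(Y, X)      = {xi_n(y, x):.3f}")      # much smaller
```

For this distribution the limit \(\xi(Y, X)\) is 1/4, so the two directions differ by a wide margin even though they describe the same pairs of points.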

Significance test

Under the null hypothesis that \(X\) and \(Y\) are independent (with \(Y\) continuous), the asymptotic distribution of \(\xi\) is known [Cha21]:

\[\sqrt{n} \, \xi_n \xrightarrow{d} \mathcal{N}\!\left(0, \tfrac{2}{5}\right).\]

This means that a one-sided z-test can be used to decide whether an observed \(\xi_n\) value is statistically significant. Given a significance level \(\alpha\), the null hypothesis of independence is rejected when

\[\frac{\sqrt{n} \, \xi_n}{\sqrt{2/5}} > z_{1-\alpha},\]

where \(z_{1-\alpha}\) is the \((1-\alpha)\)-quantile of the standard normal distribution. Setting \(\alpha = 0.05\) means accepting a dependence only if a value of \(\xi_n\) at least this large would occur with probability below 5% if \(X\) and \(Y\) were independent.
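The decision rule needs only \(\xi_n\), \(n\) and \(\alpha\). As a sketch (xi_independence_test is a hypothetical helper, not part of pyjpt), the normal quantile can be taken from statistics.NormalDist in the Python standard library. The usage lines plug in the values reported in the worked example below, where \(n = 2000\).

```python
from math import sqrt
from statistics import NormalDist


def xi_independence_test(xi, n, alpha=0.05):
    """One-sided z-test of H0: X and Y independent (Y continuous).

    Hypothetical helper. Under H0, sqrt(n) * xi_n is approximately
    N(0, 2/5), so the standardized statistic is compared with the
    (1 - alpha)-quantile of the standard normal distribution.
    """
    z = sqrt(n) * xi / sqrt(2 / 5)
    z_crit = NormalDist().inv_cdf(1 - alpha)
    return z, z > z_crit


# Values from the worked example (n = 2000).
print(xi_independence_test(0.232, 2000))   # z ≈ 16.4, reject H0
print(xi_independence_test(-0.009, 2000))  # z ≈ -0.64, keep H0
```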

Dependency Discovery

Dependency discovery computes \(\xi\) for all feature–target pairs before tree construction and retains only those pairs where the relationship is statistically significant. This restricts the impurity computation during learning to genuine dependencies, preventing the tree from wasting splits on unrelated variables.
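Conceptually, discovery is a loop over all feature–target pairs. The sketch below illustrates the idea and is not pyjpt's XiDependencyDiscovery: it works on plain Python lists keyed by variable name instead of a DataFrame and variable objects, assumes the z-test described under Significance test is applied at level alpha, and uses the hypothetical helpers xi_n (no ties in Y) and discover_dependencies. The data mirror the worked example.

```python
import random
from math import sqrt
from statistics import NormalDist


def xi_n(x, y):
    """Chatterjee's xi_n(X, Y) (hypothetical helper); assumes no ties in y."""
    n = len(x)
    by_y = sorted(range(n), key=lambda k: y[k])
    rank = {idx: r for r, idx in enumerate(by_y, start=1)}
    r_sorted = [rank[k] for k in sorted(range(n), key=lambda k: x[k])]
    jumps = sum(abs(b - a) for a, b in zip(r_sorted, r_sorted[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


def discover_dependencies(columns, features, targets, alpha=0.05):
    """Hypothetical: map each feature to the targets it significantly predicts."""
    n = len(columns[features[0]])
    # Reject independence when sqrt(n) * xi / sqrt(2/5) > z_{1 - alpha}.
    xi_crit = NormalDist().inv_cdf(1 - alpha) * sqrt(2 / 5) / sqrt(n)
    return {
        f: [t for t in targets if xi_n(columns[f], columns[t]) > xi_crit]
        for f in features
    }


random.seed(42)
n = 2000
x1 = [random.uniform(-2, 2) for _ in range(n)]
x2 = [random.uniform(-2, 2) for _ in range(n)]
y = [a * a + random.gauss(0, 1.5) for a in x1]

columns = {'X1': x1, 'X2': x2, 'Y': y}
print(discover_dependencies(columns, ['X1', 'X2'], ['Y']))
# 'X1' maps to ['Y']; 'X2' is expected to map to [], although a false
# positive still occurs with probability of about alpha.
```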

Basic usage

Pass an XiDependencyDiscovery instance as the dependencies parameter:

from jpt.trees import JPT
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)

model = JPT(
    variables,
    targets=[target_var],
    dependencies=XiDependencyDiscovery(
        alpha=0.05
    ),
    min_samples_leaf=0.01,
)
model.fit(data)

After learning, model.dependencies contains the discovered dependency map. Inspect it to see which features were identified as relevant:

for feat, targets in model.dependencies.items():
    names = [t.name for t in targets]
    print(f'{feat.name} -> {names}')

The alpha parameter controls the significance level:

  • Smaller \(\alpha\) (e.g. 0.01): stricter, fewer dependencies retained, more compact trees.

  • Larger \(\alpha\) (e.g. 0.10): more permissive, retains weaker relationships.
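To make the effect of \(\alpha\) concrete, the z-test can be rearranged into a minimum value that \(\xi_n\) must exceed, \(\xi_{\text{crit}} = z_{1-\alpha}\sqrt{2/5}/\sqrt{n}\). The sketch below tabulates it, assuming XiDependencyDiscovery applies the one-sided test from the Significance test section; xi_threshold is a hypothetical helper.

```python
from math import sqrt
from statistics import NormalDist


def xi_threshold(n, alpha):
    """Hypothetical: smallest xi_n the one-sided z-test calls significant."""
    return NormalDist().inv_cdf(1 - alpha) * sqrt(2 / 5) / sqrt(n)


for n in (100, 2000):
    row = ", ".join(
        f"alpha={a}: {xi_threshold(n, a):.3f}" for a in (0.01, 0.05, 0.10)
    )
    print(f"n={n}: {row}")
# n=100: alpha=0.01: 0.147, alpha=0.05: 0.104, alpha=0.1: 0.081
# n=2000: alpha=0.01: 0.033, alpha=0.05: 0.023, alpha=0.1: 0.018
```

Smaller \(\alpha\) raises the bar and larger samples lower it: with \(n = 2000\) and \(\alpha = 0.05\), a pair passes once \(\xi_n\) exceeds about 0.023, so on large datasets even weak dependencies pass the test.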

Persistence

The discovery strategy is preserved during serialization. When a model is saved and loaded, calling fit() again will re-discover dependencies from the new data:

model.save('model.json')
restored = JPT.load('model.json')

# Re-learning uses the same discovery strategy
restored.fit(new_data)

Backward compatibility

The dependencies parameter continues to accept None (fully connected, the default) and explicit dictionaries:

# Default: every feature depends on every target
model = JPT(variables, dependencies=None)

# Manual: only X1 influences Y
model = JPT(
    variables,
    dependencies={x1_var: [y_var]}
)

# Automatic: discover from data
model = JPT(
    variables,
    dependencies=XiDependencyDiscovery(alpha=0.05)
)

Xi-Based Pruning

Even when the global dependency structure is known, the strength of a relationship may vary across subregions of the data. Xi-based pruning tests at each candidate split whether there is still significant functional dependence in the current partition. If not, the node becomes a leaf.

Using the pruning criterion

Pass an XiPruningCriterion instance as the prune_or_split parameter of fit():

from jpt.learning.pruning import XiPruningCriterion

model = JPT(
    variables,
    targets=[target_var],
    min_samples_leaf=0.01,
)
model.fit(
    data,
    prune_or_split=XiPruningCriterion(alpha=0.05),
)

The alpha parameter has the same interpretation as above: at each node, it is the approximate probability of a false split (splitting when there is no genuine dependence).

The min_n parameter (default 30) sets the minimum partition size for the test. Below this threshold, the xi test is not applied and conventional stopping rules take effect. This handles the fact that \(\xi\) requires a minimum sample size for reliable inference (typically \(n \gtrsim 250\) for the asymptotic theory [Cha21], but the test is still informative for smaller \(n\) [DAG24]).
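The stopping decision described above can be sketched as a function of the partition size and the \(\xi_n\) values of the candidate features with the target. This is a hypothetical illustration, not pyjpt's XiPruningCriterion: it assumes a node keeps splitting if any feature–target pair passes the one-sided z-test, and it returns None to signal that the conventional stopping rules should decide.

```python
from math import sqrt
from statistics import NormalDist


def xi_stop(xi_values, n, alpha=0.05, min_n=30):
    """Hypothetical: decide whether a node of n samples becomes a leaf.

    Returns None if the partition is too small for the xi test (defer to
    conventional rules), True if no xi value is significant (make a leaf),
    and False if at least one dependence remains (keep splitting).
    """
    if n < min_n:
        return None
    xi_crit = NormalDist().inv_cdf(1 - alpha) * sqrt(2 / 5) / sqrt(n)
    return not any(xi > xi_crit for xi in xi_values)


# With n = 200 and alpha = 0.05 the threshold is about 0.074.
print(xi_stop([0.40, 0.02], n=200))   # False: 0.40 is significant, split
print(xi_stop([0.05, -0.03], n=200))  # True: nothing significant, leaf
print(xi_stop([0.90], n=20))          # None: below min_n, defer
```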

Combining Both

Dependency discovery and xi pruning are complementary: the former reduces the set of feature–target pairs globally, the latter adapts the stopping criterion locally. For maximum effect, use both:

from jpt.trees import JPT
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)
from jpt.learning.pruning import XiPruningCriterion

model = JPT(
    variables,
    targets=[target_var],
    dependencies=XiDependencyDiscovery(
        alpha=0.05
    ),
    min_samples_leaf=0.01,
)
model.fit(
    data,
    prune_or_split=XiPruningCriterion(alpha=0.05),
)

print(f'Leaves: {len(model.leaves)}')

Worked Example

Consider a dataset with three variables: \(X_1\), \(X_2\) (both uniform noise on \([-2, 2]\)), and \(Y = X_1^2 + \varepsilon\) where \(\varepsilon \sim \mathcal{N}(0, 1.5^2)\), i.e. Gaussian noise with standard deviation 1.5. By construction, \(Y\) depends on \(X_1\) but not on \(X_2\):

import numpy as np
from pandas import DataFrame
from jpt.distributions import Numeric
from jpt.variables import NumericVariable
from jpt.trees import JPT
from jpt.base.correlation import xi_correlation
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)
from jpt.learning.pruning import XiPruningCriterion

np.random.seed(42)
n = 2000
x1 = np.random.uniform(-2, 2, n)
x2 = np.random.uniform(-2, 2, n)
y = x1 ** 2 + np.random.normal(0, 1.5, n)

df = DataFrame({'X1': x1, 'X2': x2, 'Y': y})
vx1 = NumericVariable('X1', Numeric, precision=0.1)
vx2 = NumericVariable('X2', Numeric, precision=0.1)
vy = NumericVariable('Y', Numeric, precision=0.1)

# Check xi values
print(f'xi(X1, Y) = {xi_correlation(x1, y):.3f}')
print(f'xi(X2, Y) = {xi_correlation(x2, y):.3f}')

# Standard JPT
tree_std = JPT(
    [vx1, vx2, vy],
    targets=[vy],
    min_samples_leaf=0.01
)
tree_std.fit(df)

# JPT with dependency discovery + xi pruning
tree_xi = JPT(
    [vx1, vx2, vy],
    targets=[vy],
    dependencies=XiDependencyDiscovery(alpha=0.05),
    min_samples_leaf=0.01
)
tree_xi.fit(
    df,
    prune_or_split=XiPruningCriterion(alpha=0.05)
)

print(f'Standard:  {len(tree_std.leaves)} leaves')
print(f'Xi-aware:  {len(tree_xi.leaves)} leaves')

Expected output:

xi(X1, Y) = 0.232
xi(X2, Y) = -0.009
Standard:  77 leaves
Xi-aware:  5 leaves

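The small negative value of \(\xi_n(X_2, Y)\) is sampling noise around 0; in finite samples \(\xi_n\) can fall slightly below 0. The standard tree produces 77 leaves, many from splits on the irrelevant variable \(X_2\). The xi-aware tree correctly identifies \(X_1\) as the only relevant feature and stops splitting once the signal in each partition is exhausted, yielding only 5 leaves.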

Extending with Custom Discovery Strategies

The dependency discovery mechanism is extensible. To implement a custom strategy, subclass DependencyDiscovery and implement three methods, __call__, to_json and from_json:

from jpt.learning.dependency.base import (
    DependencyDiscovery,
)

class MyDiscovery(DependencyDiscovery):

    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def __call__(
            self, data, features,
            targets, variables
    ):
        # Return a dict mapping each feature
        # to its dependent targets
        ...

    def to_json(self):
        return {
            'type': self.__class__.__name__,
            'threshold': self.threshold,
        }

    @classmethod
    def from_json(cls, data):
        return cls(threshold=data['threshold'])

Subclasses are automatically registered for deserialization, so DependencyDiscovery.from_json() will dispatch to the correct class based on the 'type' key.
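Automatic registration of this kind is commonly built on Python's __init_subclass__ hook. The sketch below shows one way 'type'-based dispatch can work; Discovery and ThresholdDiscovery are hypothetical stand-ins, and this is not pyjpt's actual DependencyDiscovery code.

```python
class Discovery:
    """Hypothetical base class with a subclass registry keyed by name."""

    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass is recorded under its class name when defined.
        Discovery._registry[cls.__name__] = cls

    @classmethod
    def from_json(cls, data):
        # Called on the base class: dispatch on the 'type' key.
        # Subclasses are expected to override from_json themselves.
        return Discovery._registry[data['type']].from_json(data)


class ThresholdDiscovery(Discovery):
    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def to_json(self):
        return {'type': type(self).__name__, 'threshold': self.threshold}

    @classmethod
    def from_json(cls, data):
        return cls(threshold=data['threshold'])


restored = Discovery.from_json(ThresholdDiscovery(0.25).to_json())
print(type(restored).__name__, restored.threshold)  # ThresholdDiscovery 0.25
```

Keying the registry by class name keeps the JSON readable, but two subclasses with the same name in different modules would overwrite each other; a fully qualified name avoids that.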

References

[Cha21]

Sourav Chatterjee. A new coefficient of correlation. Journal of the American Statistical Association, 116(536):2009–2022, 2021. doi:10.1080/01621459.2020.1758115.

[DAG24]

Christoph Dalitz, Lena Arning, and Steffen Goebbels. A simple bias reduction for Chatterjee's correlation. Journal of Statistical Theory and Practice, 18(2):1–15, 2024. doi:10.1007/s42519-024-00399-y.

[SDH22]

Hongjian Shi, Mathias Drton, and Fang Han. On the power of Chatterjee's rank correlation. Biometrika, 109(2):317–333, 2022. doi:10.1093/biomet/asab028.