How to Use Dependency Discovery and Xi Pruning
By default, JPTs assume that every feature can influence every target
variable. In high-dimensional datasets this leads to unnecessarily
large trees, because the learning algorithm may chase spurious
relationships between unrelated variables. pyjpt provides two
mechanisms to address this: dependency discovery, which determines
before tree construction which feature–target pairs are genuinely
related, and xi-based pruning, which stops splitting when no
statistically significant dependence remains in a partition.
Both mechanisms are built on Chatterjee’s \(\xi\) correlation coefficient [Cha21], a rank-based measure of functional dependence whose properties make it well suited to both tasks.
Mathematical Background
Chatterjee’s \(\xi\) coefficient
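Given \(n\) paired observations \((X_1, Y_1), \ldots, (X_n, Y_n)\), sort the pairs by \(X\) so that \(X_{(1)} \leq \cdots \leq X_{(n)}\), write \(Y_{(i)}\) for the \(Y\) value paired with \(X_{(i)}\), and let \(r_{(i)}\) denote the rank of \(Y_{(i)}\) among all \(Y\) values, i.e. the number of \(j\) with \(Y_j \leq Y_{(i)}\). Then, assuming there are no ties among the \(Y_i\),
\[
\xi_n(X, Y) = 1 - \frac{3 \sum_{i=1}^{n-1} \left| r_{(i+1)} - r_{(i)} \right|}{n^2 - 1}.
\]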
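Intuition: If \(Y\) is a monotone function of \(X\), sorting by \(X\) also sorts \(Y\), so consecutive ranks differ by exactly 1, the smallest possible total, and \(\xi_n \to 1\). More generally, whenever \(Y\) is a function of \(X\), neighbouring \(X\) values tend to have \(Y\) values with nearby ranks, so the rank differences stay small and \(\xi_n \to 1\) as well. If \(X\) and \(Y\) are independent, the ranks form a random permutation, consecutive ranks differ by about \(n/3\) on average, and \(\xi_n \to 0\).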
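The two limiting cases can be checked numerically. The following is a minimal pure-Python sketch of the no-ties form of \(\xi_n\); the helper name `xi_n` is hypothetical and is not pyjpt's implementation.

```python
import random


def xi_n(x, y):
    """Chatterjee's xi_n for paired samples without ties."""
    n = len(x)
    # Rank of each y value among all y values (1 = smallest).
    rank = [0] * n
    for r, i in enumerate(sorted(range(n), key=lambda i: y[i]), start=1):
        rank[i] = r
    # Visit the pairs in order of increasing x and sum the rank jumps.
    by_x = sorted(range(n), key=lambda i: x[i])
    jumps = sum(abs(rank[b] - rank[a]) for a, b in zip(by_x, by_x[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


rng = random.Random(0)
n = 1000
x = [rng.uniform(-2.0, 2.0) for _ in range(n)]

# Y strictly increasing in X: every rank jump equals 1.
xi_monotone = xi_n(x, [2.0 * v for v in x])
# Y drawn independently of X: the ranks are a random permutation.
xi_noise = xi_n(x, [rng.gauss(0.0, 1.0) for _ in range(n)])

print(f"monotone:    {xi_monotone:.4f}")
print(f"independent: {xi_noise:.4f}")
```

In the monotone case the sum of rank jumps is \(n - 1\), so the value equals \(1 - 3/(n+1)\) rather than exactly 1; the gap closes as \(n\) grows.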
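Key properties (stated for the population value \(\xi\), the large-sample limit of \(\xi_n\); in finite samples \(\xi_n\) can be slightly negative):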
\(\xi = 0\) if and only if \(X\) and \(Y\) are independent.
\(\xi = 1\) if and only if \(Y\) is a measurable function of \(X\).
Asymmetric: \(\xi(X, Y) \neq \xi(Y, X)\) in general. This is a feature: it measures how well \(Y\) can be predicted from \(X\), which is exactly the question a split on \(X\) answers.
Distribution-free: no assumptions about the distributions of \(X\) or \(Y\).
Detects any dependence: unlike Pearson (linear) or Spearman (monotonic), \(\xi\) detects arbitrary functional relationships, including periodic or many-to-one mappings.
\(\mathcal{O}(n \log n)\) computational complexity.
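The asymmetry and the sensitivity to non-monotonic relationships can be illustrated with \(Y = X^2\) on a symmetric interval. This is a sketch using hypothetical helpers `xi_n` and `pearson` written in plain Python, assuming no ties; it does not call pyjpt.

```python
import math
import random


def xi_n(x, y):
    """Chatterjee's xi_n (no ties): how well y is a function of x."""
    n = len(x)
    rank = [0] * n
    for r, i in enumerate(sorted(range(n), key=lambda i: y[i]), start=1):
        rank[i] = r
    by_x = sorted(range(n), key=lambda i: x[i])
    jumps = sum(abs(rank[b] - rank[a]) for a, b in zip(by_x, by_x[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


def pearson(x, y):
    """Plain Pearson correlation, for comparison."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


rng = random.Random(1)
x = [rng.uniform(-1.0, 1.0) for _ in range(2000)]
y = [v * v for v in x]  # many-to-one: x and -x map to the same y

xi_xy = xi_n(x, y)    # Y is a function of X: close to 1
xi_yx = xi_n(y, x)    # X is only determined up to sign: much smaller
r_xy = pearson(x, y)  # no linear trend: close to 0

print(f"xi(X, Y) = {xi_xy:.3f}")
print(f"xi(Y, X) = {xi_yx:.3f}")
print(f"Pearson  = {r_xy:.3f}")
```

For this construction the population values are \(\xi(X, Y) = 1\) and \(\xi(Y, X) = 1/4\), while the Pearson correlation is 0, so only the direction \(X \to Y\) signals a split on \(X\) as worthwhile.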
Significance test
Under the null hypothesis that \(X\) and \(Y\) are independent (with \(Y\) continuous), the asymptotic distribution of \(\xi_n\) is known [Cha21]:
\[
\sqrt{n}\,\xi_n \xrightarrow{d} \mathcal{N}\!\left(0, \tfrac{2}{5}\right).
\]
Since \(\sqrt{n}\,\xi_n\) is asymptotically normal with variance \(2/5\) under the null, a one-sided z-test can be used to decide whether an observed \(\xi_n\) value is statistically significant. Given a significance level \(\alpha\), the null hypothesis of independence is rejected when
\[
\xi_n > z_{1-\alpha} \sqrt{\frac{2}{5n}},
\]
where \(z_{1-\alpha}\) is the \((1-\alpha)\)-quantile of the standard normal distribution. Setting \(\alpha = 0.05\) means that, when \(X\) and \(Y\) are in fact independent, the test wrongly reports a dependence in only about 5% of cases.
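The test reduces to a few lines of arithmetic. The sketch below uses only the Python standard library; the function names `xi_critical_value` and `xi_p_value` are hypothetical, and the sample size and \(\xi\) values are made up for illustration.

```python
from math import sqrt
from statistics import NormalDist


def xi_critical_value(n, alpha=0.05):
    """Value that xi_n must exceed to reject independence at level alpha."""
    z = NormalDist().inv_cdf(1.0 - alpha)
    return z * sqrt(2.0 / (5.0 * n))


def xi_p_value(xi, n):
    """One-sided asymptotic p-value under the N(0, 2/5) null."""
    z = xi * sqrt(n) / sqrt(2.0 / 5.0)
    return 1.0 - NormalDist().cdf(z)


n = 500  # illustrative sample size
crit = xi_critical_value(n, alpha=0.05)
print(f"critical value: {crit:.4f}")
for xi in (0.02, 0.10):  # illustrative xi_n values
    p = xi_p_value(xi, n)
    print(f"xi = {xi:.2f}: p = {p:.4f}, reject = {xi > crit}")
```

Comparing \(\xi_n\) with the critical value and comparing the p-value with \(\alpha\) are equivalent formulations of the same decision.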
Dependency Discovery
Dependency discovery computes \(\xi\) for all feature–target pairs before tree construction and retains only those pairs where the relationship is statistically significant. This restricts the impurity computation during learning to genuine dependencies, preventing the tree from wasting splits on unrelated variables.
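Conceptually, this step amounts to a loop over all feature–target pairs. The following pure-Python sketch illustrates the idea under that assumption; it is not pyjpt's implementation, the names `xi_n` and `discover_dependencies` are hypothetical, the data are synthetic, and ties are assumed absent.

```python
import random
from math import sqrt
from statistics import NormalDist


def xi_n(x, y):
    """Chatterjee's xi_n (no ties): how well y is a function of x."""
    n = len(x)
    rank = [0] * n
    for r, i in enumerate(sorted(range(n), key=lambda i: y[i]), start=1):
        rank[i] = r
    by_x = sorted(range(n), key=lambda i: x[i])
    jumps = sum(abs(rank[b] - rank[a]) for a, b in zip(by_x, by_x[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


def discover_dependencies(columns, features, targets, alpha=0.05):
    """Map each feature name to the targets it significantly predicts."""
    z = NormalDist().inv_cdf(1.0 - alpha)
    found = {}
    for f in features:
        crit = z * sqrt(2.0 / (5.0 * len(columns[f])))
        found[f] = [
            t for t in targets if xi_n(columns[f], columns[t]) > crit
        ]
    return found


# Synthetic data: Y depends on X1 only.
rng = random.Random(7)
n = 2000
x1 = [rng.uniform(-2.0, 2.0) for _ in range(n)]
x2 = [rng.uniform(-2.0, 2.0) for _ in range(n)]
y = [a * a + rng.gauss(0.0, 0.5) for a in x1]
columns = {"X1": x1, "X2": x2, "Y": y}

deps = discover_dependencies(columns, ["X1", "X2"], ["Y"], alpha=0.05)
print(deps)
```

Features that map to an empty list contribute nothing to the impurity computation, which is what keeps the tree from splitting on them.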
Basic usage
Pass an XiDependencyDiscovery instance as the dependencies parameter:
from jpt.trees import JPT
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)
model = JPT(
    variables,
    targets=[target_var],
    dependencies=XiDependencyDiscovery(
        alpha=0.05
    ),
    min_samples_leaf=0.01,
)
model.fit(data)
After learning, model.dependencies contains the discovered
dependency map. Inspect it to see which features were identified
as relevant:
for feat, targets in model.dependencies.items():
    names = [t.name for t in targets]
    print(f'{feat.name} -> {names}')
The alpha parameter controls the significance level:
Smaller \(\alpha\) (e.g. 0.01): stricter, fewer dependencies retained, more compact trees.
Larger \(\alpha\) (e.g. 0.10): more permissive, retains weaker relationships.
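To make the effect of \(\alpha\) concrete, the critical value that \(\xi_n\) must exceed can be tabulated for a fixed sample size. This is a small standard-library sketch; the sample size of 2000 is an arbitrary illustrative choice.

```python
from math import sqrt
from statistics import NormalDist

n = 2000  # illustrative sample size
thresholds = {}
for alpha in (0.01, 0.05, 0.10):
    z = NormalDist().inv_cdf(1.0 - alpha)
    thresholds[alpha] = z * sqrt(2.0 / (5.0 * n))
    print(f"alpha = {alpha:.2f}: xi_n must exceed {thresholds[alpha]:.4f}")
```

The threshold also shrinks with \(1/\sqrt{n}\), so on large datasets even weak relationships pass the test at any of these levels.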
Persistence
The discovery strategy is preserved during serialization. When a
model is saved and loaded, calling fit() again will re-discover
dependencies from the new data:
model.save('model.json')
restored = JPT.load('model.json')
# Re-learning uses the same discovery strategy
restored.fit(new_data)
Backward compatibility
The dependencies parameter continues to accept None (fully
connected, the default) and explicit dictionaries:
# Default: every feature depends on every target
model = JPT(variables, dependencies=None)
# Manual: only X1 influences Y
model = JPT(
    variables,
    dependencies={x1_var: [y_var]}
)
# Automatic: discover from data
model = JPT(
    variables,
    dependencies=XiDependencyDiscovery(alpha=0.05)
)
Xi-Based Pruning
Even when the global dependency structure is known, the strength of a relationship may vary across subregions of the data. Xi-based pruning tests at each candidate split whether there is still significant functional dependence in the current partition. If not, the node becomes a leaf.
Using the pruning criterion
Pass an XiPruningCriterion instance as the prune_or_split parameter of fit():
from jpt.learning.pruning import XiPruningCriterion
model = JPT(
    variables,
    targets=[target_var],
    min_samples_leaf=0.01,
)
model.fit(
    data,
    prune_or_split=XiPruningCriterion(alpha=0.05),
)
The alpha parameter is again the significance level of the test: at
each tested node it is the (asymptotic) probability of a false split,
that is, of splitting when the partition holds no genuine dependence.
The min_n parameter (default 30) sets the minimum partition size
for the test. Below this threshold, the xi test is not applied and
conventional stopping rules take effect. This accounts for the fact
that \(\xi\) requires a minimum sample size for reliable inference
(typically \(n \gtrsim 250\) for the asymptotic theory
[Cha21], but the test is still informative for
smaller \(n\) [DAG24]).
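The decision at a single node can be sketched as follows. This is an illustration of the idea only, under the assumption that each feature is tested against the target within the partition; the function `should_split` and its return convention are hypothetical and do not reflect the actual interface of XiPruningCriterion.

```python
import random
from math import sin, sqrt
from statistics import NormalDist


def xi_n(x, y):
    """Chatterjee's xi_n (no ties): how well y is a function of x."""
    n = len(x)
    rank = [0] * n
    for r, i in enumerate(sorted(range(n), key=lambda i: y[i]), start=1):
        rank[i] = r
    by_x = sorted(range(n), key=lambda i: x[i])
    jumps = sum(abs(rank[b] - rank[a]) for a, b in zip(by_x, by_x[1:]))
    return 1.0 - 3.0 * jumps / (n * n - 1)


def should_split(feature_cols, target, alpha=0.05, min_n=30):
    """Return True/False from the xi test, or None when n < min_n.

    None means "test not applicable": the caller falls back to its
    conventional stopping rules (e.g. a minimum leaf size).
    """
    n = len(target)
    if n < min_n:
        return None
    crit = NormalDist().inv_cdf(1.0 - alpha) * sqrt(2.0 / (5.0 * n))
    return any(xi_n(col, target) > crit for col in feature_cols)


# Synthetic partition of 400 samples.
rng = random.Random(3)
xs = [rng.uniform(0.0, 1.0) for _ in range(400)]
signal = [sin(6.0 * v) + rng.gauss(0.0, 0.1) for v in xs]
noise = [rng.gauss(0.0, 1.0) for _ in xs]

print(should_split([xs], signal))            # dependence present
print(should_split([xs], noise))             # independent by construction
print(should_split([xs[:20]], signal[:20]))  # below min_n: test skipped
```

Note that this sketch tests several features at the same level \(\alpha\) without any multiple-testing correction.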
Combining Both
Dependency discovery and xi pruning are complementary: the former reduces the set of feature–target pairs globally, the latter adapts the stopping criterion locally. For maximum effect, use both:
from jpt.trees import JPT
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)
from jpt.learning.pruning import XiPruningCriterion
model = JPT(
    variables,
    targets=[target_var],
    dependencies=XiDependencyDiscovery(
        alpha=0.05
    ),
    min_samples_leaf=0.01,
)
model.fit(
    data,
    prune_or_split=XiPruningCriterion(alpha=0.05),
)
print(f'Leaves: {len(model.leaves)}')
Worked Example
Consider a dataset with three variables: \(X_1\) and \(X_2\), both drawn uniformly from \([-2, 2]\), and \(Y = X_1^2 + \varepsilon\), where \(\varepsilon \sim \mathcal{N}(0, 1.5^2)\) (standard deviation 1.5). By construction, \(Y\) depends on \(X_1\) but not on \(X_2\):
import numpy as np
from pandas import DataFrame
from jpt.distributions import Numeric
from jpt.variables import NumericVariable
from jpt.trees import JPT
from jpt.base.correlation import xi_correlation
from jpt.learning.dependency import (
    XiDependencyDiscovery,
)
from jpt.learning.pruning import XiPruningCriterion
np.random.seed(42)
n = 2000
x1 = np.random.uniform(-2, 2, n)
x2 = np.random.uniform(-2, 2, n)
y = x1 ** 2 + np.random.normal(0, 1.5, n)
df = DataFrame({'X1': x1, 'X2': x2, 'Y': y})
vx1 = NumericVariable('X1', Numeric, precision=0.1)
vx2 = NumericVariable('X2', Numeric, precision=0.1)
vy = NumericVariable('Y', Numeric, precision=0.1)
# Check xi values
print(f'xi(X1, Y) = {xi_correlation(x1, y):.3f}')
print(f'xi(X2, Y) = {xi_correlation(x2, y):.3f}')
# Standard JPT
tree_std = JPT(
    [vx1, vx2, vy],
    targets=[vy],
    min_samples_leaf=0.01
)
tree_std.fit(df)
# JPT with dependency discovery + xi pruning
tree_xi = JPT(
    [vx1, vx2, vy],
    targets=[vy],
    dependencies=XiDependencyDiscovery(alpha=0.05),
    min_samples_leaf=0.01
)
tree_xi.fit(
    df,
    prune_or_split=XiPruningCriterion(alpha=0.05)
)
print(f'Standard: {len(tree_std.leaves)} leaves')
print(f'Xi-aware: {len(tree_xi.leaves)} leaves')
Expected output:
xi(X1, Y) = 0.232
xi(X2, Y) = -0.009
Standard: 77 leaves
Xi-aware: 5 leaves
The standard tree produces 77 leaves, many from splits on the irrelevant variable \(X_2\). The xi-aware tree correctly identifies \(X_1\) as the only relevant feature and stops splitting once the signal in each partition is exhausted, yielding only 5 leaves.
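The discovery result can be cross-checked by hand from the reported \(\xi\) values. The sketch below applies the asymptotic z-test with \(n = 2000\) and \(\alpha = 0.05\) to the two numbers printed in the expected output; it uses only the standard library and does not re-run pyjpt.

```python
from math import sqrt
from statistics import NormalDist

n, alpha = 2000, 0.05
crit = NormalDist().inv_cdf(1.0 - alpha) * sqrt(2.0 / (5.0 * n))

# xi values copied from the expected output of the worked example.
reported = {"X1": 0.232, "X2": -0.009}
significant = {}
for name, xi in reported.items():
    z = xi * sqrt(n) / sqrt(2.0 / 5.0)  # standardized test statistic
    significant[name] = xi > crit
    print(f"{name}: z = {z:.2f}, significant = {significant[name]}")
```

The critical value is roughly 0.023, so \(\xi(X_1, Y) = 0.232\) lies far above it while \(\xi(X_2, Y) = -0.009\) lies below it, consistent with \(X_1\) being retained as the only relevant feature.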
Extending with Custom Discovery Strategies
The dependency discovery mechanism is extensible. To implement a
custom strategy, subclass DependencyDiscovery and implement three
methods, __call__(), to_json() and from_json():
from jpt.learning.dependency.base import (
    DependencyDiscovery,
)

class MyDiscovery(DependencyDiscovery):

    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def __call__(
        self, data, features,
        targets, variables
    ):
        # Return a dict mapping each feature
        # to its dependent targets
        ...

    def to_json(self):
        return {
            'type': self.__class__.__name__,
            'threshold': self.threshold,
        }

    @classmethod
    def from_json(cls, data):
        return cls(threshold=data['threshold'])
Subclasses are automatically registered for deserialization, so
DependencyDiscovery.from_json() will dispatch to the correct
class based on the 'type' key.
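One common way to obtain this kind of automatic registration in Python is the `__init_subclass__` hook. The sketch below shows that pattern with a hypothetical stand-in class `Strategy`; it is an assumption about how such dispatch can work, not a description of pyjpt's internals.

```python
class Strategy:
    """Hypothetical base class; stands in for DependencyDiscovery."""

    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass records itself under its class name.
        Strategy._registry[cls.__name__] = cls

    def to_json(self):
        raise NotImplementedError

    @classmethod
    def from_json(cls, data):
        # Called on the base class: look up the concrete class by its
        # 'type' key and delegate to that class's own from_json.
        target = Strategy._registry[data["type"]]
        return target.from_json(data)


class ThresholdStrategy(Strategy):
    """Hypothetical concrete strategy with one parameter."""

    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def to_json(self):
        return {"type": type(self).__name__, "threshold": self.threshold}

    @classmethod
    def from_json(cls, data):
        return cls(threshold=data["threshold"])


payload = ThresholdStrategy(threshold=0.25).to_json()
restored = Strategy.from_json(payload)
print(type(restored).__name__, restored.threshold)
```

With this pattern a subclass only needs to be imported before deserialization for its name to be known to the registry.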
References
[Cha21] Sourav Chatterjee. A new coefficient of correlation. Journal of the American Statistical Association, 116(536):2009–2022, 2021. doi:10.1080/01621459.2020.1758115.
[DAG24] Christoph Dalitz, Lena Arning, and Steffen Goebbels. A simple bias reduction for Chatterjee's correlation. Journal of Statistical Theory and Practice, 18(2):1–15, 2024. doi:10.1007/s42519-024-00399-y.
[SDH22] Hongjian Shi, Mathias Drton, and Fang Han. On the power of Chatterjee's rank correlation. Biometrika, 109(2):317–333, 2022. doi:10.1093/biomet/asab028.