Introduction
What are JPTs?
Joint Probability Trees (JPTs) are non-parametric probabilistic models that learn and represent the joint distribution \(P(\mathcal{X})\) over a set of variables \(\mathcal{X}\) directly from data.
A JPT partitions the data space into a set of regions using a decision tree. In each leaf region, the distribution over all variables is approximated by a fully factorised product of univariate distributions. The overall joint distribution is a mixture across all leaves:

\[ P(\mathcal{X}) = \sum_{\lambda \in \Lambda} P(L = \lambda) \prod_{X \in \mathcal{X}} P(X \mid L = \lambda) \]
where \(\Lambda\) is the set of leaves and \(P(L=\lambda)\) is the prior probability (mixing weight) of leaf \(\lambda\).
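To make the leaf mixture concrete, here is a minimal sketch in plain Python that evaluates the joint distribution of a hypothetical two-leaf model over two symbolic variables. The dictionary layout (`LEAVES`, `DOMAIN`) and the function `joint` are illustrative assumptions, not the pyjpt API.

```python
from itertools import product
from math import prod

# Hypothetical toy JPT: two leaves over two symbolic variables.
# Each leaf is (P(L = λ), {variable: {value: P(value | L = λ)}}).
# The leaves have disjoint support on "colour", as if the tree split on it.
LEAVES = [
    (0.6, {"colour": {"red": 1.0}, "size": {"small": 0.7, "large": 0.3}}),
    (0.4, {"colour": {"blue": 0.5, "green": 0.5}, "size": {"small": 0.2, "large": 0.8}}),
]
DOMAIN = {"colour": ["red", "blue", "green"], "size": ["small", "large"]}


def joint(x):
    """P(x) = sum over leaves of P(L = λ) * product over variables of P(x_X | L = λ)."""
    return sum(
        weight * prod(dists[var].get(val, 0.0) for var, val in x.items())
        for weight, dists in LEAVES
    )


# Summing the joint over every complete assignment gives 1 (up to rounding).
total = sum(joint(dict(zip(DOMAIN, values))) for values in product(*DOMAIN.values()))
```

For example, `joint({"colour": "red", "size": "small"})` is \(0.6 \cdot 1.0 \cdot 0.7 = 0.42\) up to floating-point rounding, since only the first leaf supports `"red"`.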
Variable Types
pyjpt natively handles three types of variables in a single model:
| Variable type | Data type | Leaf distribution |
|---|---|---|
| Numeric | continuous (real-valued) | Quantile-based (piecewise linear CDF) |
| Symbolic | categorical | Multinomial |
| Integer | integer domain | Discrete uniform / histogram |
Use infer_from_dataframe() to create the appropriate variable type for each column automatically from a DataFrame’s column dtypes.
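As a rough illustration of the quantile-based leaf distribution for numeric variables, the sketch below builds a piecewise linear CDF by interpolating linearly between the sorted distinct samples, with evenly spaced CDF levels. The class `PiecewiseLinearCDF` and this simple one-knot-per-sample construction are assumptions for illustration. They do not necessarily reflect how pyjpt fits its quantile distributions.

```python
from bisect import bisect_right


class PiecewiseLinearCDF:
    """Hypothetical sketch: a CDF that interpolates linearly between sorted samples."""

    def __init__(self, samples):
        xs = sorted(set(samples))  # distinct knots, in increasing order
        if len(xs) < 2:
            raise ValueError("need at least two distinct samples")
        self.xs = xs
        # Evenly spaced CDF levels from 0 at the smallest to 1 at the largest knot.
        self.ps = [i / (len(xs) - 1) for i in range(len(xs))]

    def cdf(self, x):
        if x <= self.xs[0]:
            return 0.0
        if x >= self.xs[-1]:
            return 1.0
        i = bisect_right(self.xs, x)  # xs[i - 1] <= x < xs[i]
        x0, x1 = self.xs[i - 1], self.xs[i]
        p0, p1 = self.ps[i - 1], self.ps[i]
        return p0 + (p1 - p0) * (x - x0) / (x1 - x0)

    def prob(self, lo, hi):
        """P(lo < X <= hi) as a difference of two CDF values."""
        return self.cdf(hi) - self.cdf(lo)
```

Because the CDF is piecewise linear, the corresponding density is piecewise constant. The probability of an interval is a difference of two CDF values. For example, `PiecewiseLinearCDF([0, 1, 2, 4]).prob(0, 4)` is 1.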
Why JPTs?
No structural assumptions — the tree partition is learned from data; no prior knowledge about dependencies or distribution families is required.
Hybrid support — symbolic and continuous variables coexist in a single model without manual encoding.
Tractable inference — all query types (marginal, conditional, posterior, MPE) are computed in a single pass over the tree.
White-box — every inference result traces back to specific leaves, enabling interpretable explanations.
Linear complexity — training and inference both scale linearly in the number of leaves.
Supported Inference Types
| Query | API method |
|---|---|
| Marginal \(P(Q)\) | |
| Conditional \(P(Q \mid E)\) | |
| Posterior distribution | |
| Expectation | |
| Most Probable Explanation (MPE) | |
| k-MPE | |
| Conditional JPT | |
Theory
Probabilistic Circuits
A JPT is a shallow, deterministic probabilistic circuit (PC): a tree-shaped computational graph of deterministic sum nodes, with fully factorising product nodes at the leaves. For more background on probabilistic circuits see [CVVdB20].
The sum nodes are decision nodes, as in a decision tree. Each holds one variable and a split value that partitions the data into two subsets. The product nodes fully factorise the distribution in their leaf into independent univariate distributions, represented by quantile distributions for numeric variables and multinomials for symbolic variables.
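The following sketch models the two node types as plain Python dataclasses. A `Decision` (sum) node holds one variable and a split value, and a `Leaf` (product) node holds a prior and one univariate distribution per variable. Because the sum nodes are deterministic, evaluating a complete assignment only needs to follow one root-to-leaf path. The class names, the `< split` convention and the toy tree are hypothetical, and pyjpt's own node classes may differ.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Dict, Union


@dataclass
class Leaf:
    prior: float                          # P(L = λ), the mixing weight
    dists: Dict[str, Dict[Any, float]]    # univariate P(X = x | L = λ) per variable


@dataclass
class Decision:
    variable: str   # the single variable this sum node tests
    split: float    # values < split go left, the rest go right
    left: "Node"
    right: "Node"


Node = Union[Leaf, Decision]


def leaf_density(leaf: Leaf, x: Dict[str, Any]) -> float:
    # Product node: the leaf fully factorises over all variables.
    return leaf.prior * prod(leaf.dists[v].get(x[v], 0.0) for v in leaf.dists)


def density(node: Node, x: Dict[str, Any]) -> float:
    # Deterministic sum nodes: a complete assignment satisfies exactly one branch,
    # so we follow that branch instead of summing over both children.
    while isinstance(node, Decision):
        node = node.left if x[node.variable] < node.split else node.right
    return leaf_density(node, x)


# Hypothetical two-leaf tree splitting an integer variable "age" at 3.
tree = Decision(
    variable="age",
    split=3,
    left=Leaf(0.5, {"age": {1: 0.5, 2: 0.5}, "smoker": {"yes": 0.1, "no": 0.9}}),
    right=Leaf(0.5, {"age": {3: 0.25, 4: 0.75}, "smoker": {"yes": 0.4, "no": 0.6}}),
)
```

Since the two leaves have disjoint support on `age`, following one path gives the same value as summing both children. For example, `density(tree, {"age": 2, "smoker": "no"})` is \(0.5 \cdot 0.5 \cdot 0.9\).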
Marginal and Conditional Queries
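A marginal query (MAR) computes the probability of a partial assignment \(\mathcal{E} = \mathbf{e}\) by integrating out all other variables:

\[ P(\mathcal{E} = \mathbf{e}) = \int P(\mathcal{E} = \mathbf{e}, Z = \mathbf{z}) \, d\mathbf{z} \]

where \(Z = \mathcal{X} \setminus \mathcal{E}\) are the unassigned variables. For symbolic and integer variables, the integral is a sum.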
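Marginalising a JPT is cheap because every leaf factorises and each univariate distribution sums (or integrates) to one. Summing out \(Z\) therefore just drops the factors of the unassigned variables. The sketch below checks this on a hypothetical discrete model by comparing a brute-force sum over all completions of \(Z\) with the shortcut. The names `LEAVES`, `DOMAIN`, `leaf_mixture` and `marginal_brute_force` are illustrative.

```python
from itertools import product
from math import prod

# Hypothetical toy model: leaves as (P(L = λ), {variable: {value: P(value | L = λ)}}).
LEAVES = [
    (0.3, {"weather": {"sunny": 1.0}, "wind": {"low": 0.8, "high": 0.2},
           "play": {"yes": 0.9, "no": 0.1}}),
    (0.7, {"weather": {"rainy": 0.6, "cloudy": 0.4}, "wind": {"low": 0.5, "high": 0.5},
           "play": {"yes": 0.3, "no": 0.7}}),
]
DOMAIN = {"weather": ["sunny", "rainy", "cloudy"], "wind": ["low", "high"], "play": ["yes", "no"]}


def leaf_mixture(assignment):
    """Sum over leaves of the prior times the factors of the *assigned* variables only."""
    return sum(w * prod(d[v].get(x, 0.0) for v, x in assignment.items()) for w, d in LEAVES)


def marginal_brute_force(evidence):
    """P(E = e) = sum over every completion z of the unassigned variables Z of P(e, z)."""
    free = [v for v in DOMAIN if v not in evidence]
    return sum(
        leaf_mixture({**evidence, **dict(zip(free, z))})
        for z in product(*(DOMAIN[v] for v in free))
    )


e = {"play": "yes"}
print(marginal_brute_force(e), leaf_mixture(e))  # both equal 0.3*0.9 + 0.7*0.3 up to rounding
```

The shortcut touches each leaf once, which matches the single-pass inference described earlier.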
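A conditional query follows from two marginal queries:

\[ P(Q = \mathbf{q} \mid \mathcal{E} = \mathbf{e}) = \frac{P(Q = \mathbf{q}, \mathcal{E} = \mathbf{e})}{P(\mathcal{E} = \mathbf{e})} \]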
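A conditional query is the ratio of two marginal queries, and each marginal is computed by dropping the factors of unassigned variables in every leaf. The sketch below assumes a hypothetical discrete model stored as a list of `(prior, distributions)` leaves. The names `marginal` and `conditional` are illustrative, not pyjpt methods.

```python
from math import prod

# Hypothetical toy model: leaves as (P(L = λ), {variable: {value: P(value | L = λ)}}).
LEAVES = [
    (0.3, {"weather": {"sunny": 1.0}, "wind": {"low": 0.8, "high": 0.2},
           "play": {"yes": 0.9, "no": 0.1}}),
    (0.7, {"weather": {"rainy": 0.6, "cloudy": 0.4}, "wind": {"low": 0.5, "high": 0.5},
           "play": {"yes": 0.3, "no": 0.7}}),
]


def marginal(assignment):
    """P(assignment): unassigned variables contribute a factor of 1 in every leaf."""
    return sum(w * prod(d[v].get(x, 0.0) for v, x in assignment.items()) for w, d in LEAVES)


def conditional(query, evidence):
    """P(Q = q | E = e) = P(Q = q, E = e) / P(E = e)."""
    p_e = marginal(evidence)
    if p_e == 0.0:
        raise ValueError("evidence has probability zero")
    return marginal({**query, **evidence}) / p_e
```

With this model, `conditional({"weather": "sunny"}, {"play": "yes"})` is \(0.27 / 0.48 = 0.5625\) up to rounding. The variable `wind` is never assigned, so it drops out of both marginals.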
Most Probable Explanation
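MPE (a.k.a. MAP) finds the most probable assignment of the unassigned variables \(Z\) given the evidence:

\[ \mathbf{z}^{*} = \arg\max_{\mathbf{z}} P(Z = \mathbf{z} \mid \mathcal{E} = \mathbf{e}) \]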
JPTs return a set of results rather than a single assignment. The piecewise structure of the leaf distributions allows multiple maxima, and a maximum can be an interval rather than a single point.
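The sketch below illustrates why MPE yields a set of results. Within a fully factorised leaf the maximum is found variable by variable, and every value that ties for a variable's maximum gives another maximiser. It assumes a hypothetical discrete model whose leaves have disjoint support (as produced by deterministic splits) and covers only ties among symbolic values, not interval-valued maxima of numeric variables. The function `mpe` is illustrative, not pyjpt's API.

```python
from itertools import product
from math import isclose, prod

# Hypothetical toy leaves, each (P(L = λ), {variable: {value: P(value | L = λ)}}),
# with disjoint support on "weather" as if the tree had split on it.
LEAVES = [
    (0.4, {"weather": {"sunny": 1.0}, "play": {"yes": 0.5, "no": 0.5}}),
    (0.6, {"weather": {"rainy": 0.5, "cloudy": 0.5}, "play": {"yes": 0.2, "no": 0.8}}),
]


def mpe(evidence=None):
    """Return (max of P(z, e), set of all complete assignments attaining it).

    Maximising P(z, e) over z gives the same maximisers as P(z | e), since P(e)
    is constant. Assumes pairwise disjoint leaf support, so the maximum of the
    mixture is the best of the per-leaf maxima.
    """
    evidence = evidence or {}
    best, results = 0.0, set()
    for weight, dists in LEAVES:
        score, choices = weight, []
        for var, dist in dists.items():
            if var in evidence:
                score *= dist.get(evidence[var], 0.0)
                choices.append([evidence[var]])
            else:
                top = max(dist.values())
                score *= top
                # Keep every value that ties for the maximum: results form a set.
                choices.append([val for val, p in dist.items() if p == top])
        candidates = {tuple(sorted(zip(dists, combo))) for combo in product(*choices)}
        if results and isclose(score, best):
            results |= candidates
        elif score > best:
            best, results = score, candidates
    return best, results
```

Without evidence, the second leaf wins with \(0.6 \cdot 0.5 \cdot 0.8 = 0.24\). `mpe()` then returns two assignments, `rainy` and `cloudy` (each with `play = "no"`), because the two weather values tie.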
Probabilistic Learning
Generative Learning
In generative mode (the default), the tree is built by a modified C4.5 algorithm that maximises information gain across all variables. Each leaf stores a fully factorised product distribution. This mode models the full joint \(P(\mathcal{X})\).
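Once the splits are fixed, filling in the leaves is straightforward. Each leaf's prior \(P(L = \lambda)\) is the fraction of training rows it holds, and each variable gets its own distribution fitted to those rows. The sketch below does this for symbolic variables only, using empirical frequencies (the maximum-likelihood multinomial). The split here is chosen by hand rather than by information gain, and `fit_leaf` is a hypothetical helper.

```python
from collections import Counter


def fit_leaf(leaf_rows, n_total):
    """Fit one leaf: the prior is the share of training rows it holds, and each
    variable gets an independent empirical (multinomial) distribution."""
    prior = len(leaf_rows) / n_total
    dists = {}
    for var in leaf_rows[0]:
        counts = Counter(row[var] for row in leaf_rows)
        dists[var] = {val: c / len(leaf_rows) for val, c in counts.items()}
    return prior, dists


rows = [
    {"outlook": "sunny", "windy": "no", "play": "no"},
    {"outlook": "sunny", "windy": "yes", "play": "no"},
    {"outlook": "sunny", "windy": "no", "play": "yes"},
    {"outlook": "rainy", "windy": "yes", "play": "yes"},
    {"outlook": "rainy", "windy": "no", "play": "yes"},
]
# Suppose the learner has split on outlook == "sunny" (chosen by hand here).
leaves = [
    fit_leaf([r for r in rows if r["outlook"] == "sunny"], len(rows)),
    fit_leaf([r for r in rows if r["outlook"] != "sunny"], len(rows)),
]
```

The priors of all leaves sum to one, and every leaf stores a product of independent univariate distributions, one per variable.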
Discriminative Learning
In discriminative mode, the impurity computation is restricted to a designated set of target variables \(Y\). Splits are chosen to maximise information gain with respect to \(Y\), while features \(X = \mathcal{X} \setminus Y\) serve solely as split candidates. This concentrates the model’s capacity on predicting \(Y\) and is well-suited for classification and regression.
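The two modes differ only in which variables enter the impurity. The sketch below is a simplified, C4.5-style greedy search over numeric threshold splits, where `score_vars` lists all variables in generative mode and only the targets \(Y\) in discriminative mode. The impurity measures (variance for numeric, Gini index for symbolic columns) and the relative-reduction normalisation are assumptions for illustration, not necessarily pyjpt's exact criterion.

```python
from statistics import pvariance


def gini(column):
    n = len(column)
    counts = {}
    for value in column:
        counts[value] = counts.get(value, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def impurity(column):
    # Illustrative choice: variance for numeric columns, Gini index for symbolic ones.
    return pvariance(column) if isinstance(column[0], (int, float)) else gini(column)


def gain(rows, split_var, threshold, score_vars):
    left = [r for r in rows if r[split_var] < threshold]
    right = [r for r in rows if r[split_var] >= threshold]
    if not left or not right:
        return 0.0
    total = 0.0
    for v in score_vars:
        parent = impurity([r[v] for r in rows])
        if parent == 0:
            continue
        children = (len(left) * impurity([r[v] for r in left])
                    + len(right) * impurity([r[v] for r in right])) / len(rows)
        total += (parent - children) / parent  # relative reduction keeps variables comparable
    return total


def best_split(rows, score_vars):
    """Return (gain, variable, threshold); every numeric column is a split candidate."""
    best = (0.0, None, None)
    for var, value in rows[0].items():
        if not isinstance(value, (int, float)):
            continue
        values = sorted({r[var] for r in rows})
        for lo, hi in zip(values, values[1:]):
            g = gain(rows, var, (lo + hi) / 2, score_vars)
            if g > best[0]:
                best = (g, var, (lo + hi) / 2)
    return best


# Toy data: y follows x2, while x1 and x3 share a strong pattern unrelated to y.
x1 = [0, 10, 0, 10, 0, 10, 0, 10]
rows = [{"x1": a, "x2": i + 1, "x3": 2 * a, "y": "A" if i < 4 else "B"} for i, a in enumerate(x1)]

generative = best_split(rows, score_vars=["x1", "x2", "x3", "y"])
discriminative = best_split(rows, score_vars=["y"])
```

On this toy data the generative search splits on `x1`, because that split explains both `x1` and `x3`. The discriminative search with `score_vars=["y"]` instead splits on `x2` at 4.5, which separates the two classes of `y` perfectly.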
Activate discriminative mode via the targets argument:
model = JPT(variables, targets=[varnames['species']])
model.fit(df)