Implementing Custom Regression Models
RegressionTables.jl is designed to be used with any custom model you create. The easiest way for this to work is to copy the StatsAPI.jl API, if that API is fully implemented then RegressionTables.jl will work out of the box.
This guide demonstrates how to create custom regression models that work seamlessly with RegressionTables.jl. When you have a regression model type that isn't natively supported, you can extend the package by implementing the required interface methods.
Overview
To make a custom model work with RegressionTables.jl, you need to implement several key components:
- Model Type Definition: Create a struct that subtypes
RegressionModel - StatsAPI Interface: Implement required methods for coefficient extraction and model statistics
- RegressionTables Integration: Add specific methods for table generation and formatting
- Optional Enhancements: Custom statistics, confidence intervals, and display options
Required Interface Methods
The core functions you need to implement are:
# Essential StatsAPI methods (minimum required)
StatsAPI.coef(model::YourModelType) # Coefficient estimates
# For coefficient names, implement ONE of:
StatsAPI.coefnames(model::YourModelType) # Variable names
# OR
StatsModels.formula(model::YourModelType) # Formula (if using formula interface)
# For response variable name, implement ONE of:
StatsAPI.responsename(model::YourModelType) # Response variable name
# OR
StatsModels.formula(model::YourModelType) # Formula (if using formula interface)
# For standard errors, implement ONE of:
StatsAPI.stderror(model::YourModelType) # Standard errors directly
# OR
StatsAPI.vcov(model::YourModelType) # Variance-covariance matrix (default implementation in StatsAPI)
# For model type identification, implement ONE of:
StatsAPI.islinear(model::YourModelType) # Returns true/false for linear models
# OR
RegressionTables.RegressionType(model::YourModelType) # Custom model type displayExample 1: Basic Implementation with Formula Interface
This example shows a straightforward implementation suitable for most linear regression use cases, including full formula support:
Step 1: Define the Model Structure
using StatsAPI, StatsModels, RegressionTables, RDatasets, Statistics
struct MyStatsModel <: RegressionModel
coef::Vector{Float64} # Coefficient estimates
vcov::Matrix{Float64} # Variance-covariance matrix
dof::Int # Degrees of freedom (total parameters)
dof_residual::Int # Residual degrees of freedom
nobs::Int # Number of observations
rss::Float64 # Residual sum of squares
tss::Float64 # Total sum of squares
coefnames::Vector{String} # Variable names
responsename::String # Dependent variable name
formula::FormulaTerm # Original formula
formula_schema::FormulaTerm # Processed formula with schema
endKey Points:
- The struct must subtype
RegressionModelto be recognized by RegressionTables - Store all necessary information for computing statistics
- Include formula information if you want formula-based fitting
Step 2: Implement the StatsAPI Interface
# Methods to match the StatsAPI interface
StatsAPI.coef(m::MyStatsModel) = m.coef
# Coefficient and response names (using stored values)
StatsAPI.coefnames(m::MyStatsModel) = m.coefnames
StatsAPI.responsename(m::MyStatsModel) = m.responsename
# Standard errors (via variance-covariance matrix)
StatsAPI.vcov(m::MyStatsModel) = m.vcov
# Model type identification
StatsAPI.islinear(m::MyStatsModel) = true
# Additional useful methods for regression statistics
StatsAPI.nobs(m::MyStatsModel) = m.nobs
StatsAPI.dof(m::MyStatsModel) = m.dof
StatsAPI.dof_residual(m::MyStatsModel) = m.dof_residual
StatsAPI.rss(m::MyStatsModel) = m.rss
StatsAPI.nulldeviance(m::MyStatsModel) = m.tss
StatsAPI.deviance(m::MyStatsModel) = StatsAPI.rss(m)
StatsAPI.mss(m::MyStatsModel) = nulldeviance(m) - StatsAPI.rss(m)
StatsAPI.r2(m::MyStatsModel) = StatsAPI.r2(m, :devianceratio)
# Formula support (enables formula-based fitting), formula takes
# precedence over `coefnames` and `reponsename`
StatsModels.formula(m::MyStatsModel) = m.formula_schemaKey Points:
coef()is the absolute minimum requirementcoefnames()andresponsename()provide variable names for displayvcov()enables automatic standard error calculationislinear(true)indicates this is a linear model- R² calculation uses the deviance ratio method (standard for linear models)
Step 3: Implement the Fitting Function
function StatsAPI.fit(::Type{MyStatsModel}, f::FormulaTerm, df::DataFrame)
# Data preparation
df = dropmissing(df)
f_schema = apply_schema(f, schema(f, df))
y, X = modelcols(f_schema, df)
response_name, coefnames_exo = coefnames(f_schema)
n, p = size(X)
# Ordinary least squares estimation
β = X \ y
ŷ = X * β
res = y - ŷ
# Calculate statistics
rss = sum(abs2, res)
tss = sum(abs2, y .- mean(y))
dof = p
dof_residual = n - p
vcov = inv(X'X) * rss / dof_residual
# Create and return model
MyStatsModel(β, vcov, dof, dof_residual, n, rss, tss,
coefnames_exo, response_name, f, f_schema)
endKey Points:
- Uses StatsModels.jl for formula processing and data preparation
- Implements standard OLS estimation
- Computes variance-covariance matrix for standard errors
- Returns a fully populated model struct
Step 4: Usage and Output
# Example usage
df = dataset("plm", "Cigar")
rr1 = StatsAPI.fit(MyStatsModel, @formula(Sales ~ 1 + Price + NDI), df)
rr2 = StatsAPI.fit(MyStatsModel, @formula(Sales ~ 1 + Price), df)
rr3 = StatsAPI.fit(MyStatsModel, @formula(Sales ~ 1 + NDI), df)
regtable(rr1, rr2, rr3)This produces:
--------------------------------------------------
Sales
------------------------------------
(1) (2) (3)
--------------------------------------------------
(Intercept) 138.480*** 139.734*** 132.981***
(1.427) (1.521) (1.538)
Price -0.938*** -0.230***
(0.054) (0.019)
NDI 0.007*** -0.001***
(0.000) (0.000)
--------------------------------------------------
N 1,380 1,380 1,380
R2 0.209 0.097 0.034
--------------------------------------------------Example 2: Advanced Implementation with Custom Features
This example demonstrates a more sophisticated implementation with numerical optimization, custom statistics, confidence intervals, and specialized display options:
Step 1: Define the Advanced Model Structure
using Optim, ForwardDiff, NamedArrays, HypothesisTests, RegressionTables, StatsAPI, NLSolversBase, LinearAlgebra
n = 40 # Number of observations
nvar = 2 # Number of variables
β = ones(nvar) * 3.0 # True coefficients
x = [ 1.0 0.156651 # X matrix of explanatory variables plus constant
1.0 -1.34218
1.0 0.238262
1.0 -0.496572
1.0 1.19352
1.0 0.300229
1.0 0.409127
1.0 -0.88967
1.0 -0.326052
1.0 -1.74367
1.0 -0.528113
1.0 1.42612
1.0 -1.08846
1.0 -0.00972169
1.0 -0.85543
1.0 1.0301
1.0 1.67595
1.0 -0.152156
1.0 0.26666
1.0 -0.668618
1.0 -0.36883
1.0 -0.301392
1.0 0.0667779
1.0 -0.508801
1.0 -0.352346
1.0 0.288688
1.0 -0.240577
1.0 -0.997697
1.0 -0.362264
1.0 0.999308
1.0 -1.28574
1.0 -1.91253
1.0 0.825156
1.0 -0.136191
1.0 1.79925
1.0 -1.10438
1.0 0.108481
1.0 0.847916
1.0 0.594971
1.0 0.427909]
ε = [ 0.5539830489065279 # Errors
-0.7981494315544392
0.12994853889935182
0.23315434715658184
-0.1959788033050691
-0.644463980478783
-0.04055657880388486
-0.33313251280917094
-0.315407370840677
0.32273952815870866
0.56790436131181
0.4189982390480762
-0.0399623088796998
-0.2900421677961449
-0.21938513655749814
-0.2521429229103657
0.0006247891825243118
-0.694977951759846
-0.24108791530910414
0.1919989647431539
0.15632862280544485
-0.16928298502504732
0.08912288359190582
0.0037707641031662006
-0.016111044809837466
0.01852191562589722
-0.762541135294584
-0.7204431774719634
-0.04394527523005201
-0.11956323865320413
-0.6713329013627437
-0.2339928433338628
-0.6200532213195297
-0.6192380993792371
0.08834918731846135
-0.5099307915921438
0.41527207925609494
-0.7130133329859893
-0.531213372742777
-0.09029672309221337]
y = x * β + ε; # Generate Data
struct CustomModel <: StatsAPI.RegressionModel
coef::Vector{Float64} # Parameter estimates
vcov::Matrix{Float64} # Variance-covariance matrix
dof_residual::Int # Residual degrees of freedom
nobs::Int # Number of observations
coefnames::Vector{String} # Parameter names
responsename::String # Dependent variable name
LogLikelihood::Float64 # Log-likelihood value
BIC::Float64 # Bayesian Information Criterion
BS::Float64 # Jarque-Bera test statistic
civec::Matrix{Float64} # Confidence intervals matrix
endKey Points:
- Stores additional statistics like log-likelihood and BIC
- Includes diagnostic test results (Jarque-Bera)
- Pre-computes confidence intervals for efficient display
Step 2: Define Custom Regression Statistics
# Custom statistic for displaying the test
struct MyStatistic <: RegressionTables.AbstractRegressionStatistic
val::Union{Float64, Nothing}
end
# Constructor for the custom statistic
MyStatistic(m::CustomModel) = MyStatistic(m.BS)
MyStatistic(m::StatsAPI.RegressionModel) = MyStatistic(nothing) # Default for other models
# Label for display in tables
RegressionTables.label(render::AbstractRenderType, x::Type{MyStatistic}) = "Normality"Key Points:
- Custom statistics must subtype
AbstractRegressionStatistic - Provide constructors for your model type and a fallback for others
- Define how the statistic should be labeled in tables
Step 3: Implement Core Interface Methods
# Essential StatsAPI methods (minimum required)
StatsAPI.coef(m::CustomModel) = m.coef
# Coefficient and response names
StatsAPI.coefnames(m::CustomModel) = m.coefnames
StatsAPI.responsename(m::CustomModel) = m.responsename
# Standard errors (via variance-covariance matrix)
StatsAPI.vcov(m::CustomModel) = m.vcov
StatsAPI.dof_residual(m::CustomModel) = m.dof_residual
StatsAPI.nobs(m::CustomModel) = m.nobs
# Model fit statistics
StatsAPI.loglikelihood(m::CustomModel) = m.LogLikelihood
StatsAPI.bic(m::CustomModel) = m.BIC
# Confidence intervals
StatsAPI.confint(m::CustomModel; level::Real = 0.95) = m.civec
# Model type identification (using RegressionType instead of islinear)
RegressionTables.RegressionType(m::CustomModel) = RegressionTables.RegressionType("My type")
# Define default statistics for this model type
RegressionTables.default_regression_statistics(m::CustomModel) = [Nobs, MyStatistic, LogLikelihood, BIC]
# Remove significance stars for this model
RegressionTables.default_symbol(render::AbstractRenderType) = ""Key Points:
- Implements the same essential methods as Example 1
- Uses
RegressionType()instead ofislinear()for custom model identification default_regression_statistics()defines which statistics appear by default- Can customize display elements like significance stars
Step 4: Custom Confidence Interval Formatting
function Base.repr(render::AbstractRenderType, x::ConfInt;
digits=RegressionTables.default_digits(render, x), args...)
if RegressionTables.value(x) == (0, 0)
repr(render, "restricted")
else
RegressionTables.below_decoration(render,
repr(render, RegressionTables.value(x)[1]; digits) * ", " *
repr(render, RegressionTables.value(x)[2]; digits))
end
endKey Points:
- Customizes how confidence intervals are displayed
- Handles special case of restricted parameters
- Uses the package's formatting conventions
Step 5: Advanced Fitting with Optimization
function StatsAPI.fit(::Type{CustomModel}, X, Y, s_v, restricted_parameters)
# Define log-likelihood function
function Log_Likelihood(X, Y, params::AbstractVector{T}, restricted_parameters) where T
full_parameter_vector = vcat(params, restricted_parameters)
n = size(X, 1)
σ = exp(full_parameter_vector["Sigma"])
llike = -n/2*log(2π) - n/2*log(σ^2) -
(sum((Y - X[:,1] * full_parameter_vector["Beta 1"] -
X[:,2] * full_parameter_vector["Beta 2"]).^2) / (2σ^2))
return -llike # Return negative for minimization
end
# Numerical optimization
func = TwiceDifferentiable(vars -> Log_Likelihood(X, Y, vars, restricted_parameters),
s_v; autodiff=:forward)
opt = optimize(func, s_v)
parameters = Optim.minimizer(opt)
# Compute standard errors from Hessian
var_cov_matrix = inv(hessian!(func, parameters))
padding = zeros(length(parameters), length(restricted_parameters))
augmented_var_cov_matrix = [var_cov_matrix padding;
padding' I(length(restricted_parameters))]
# Calculate additional statistics
ll = -opt.minimum
n = size(X, 1)
names_param = names(vcat(parameters, restricted_parameters), 1)
# Confidence intervals with transformations
left_end = parameters - 1.96 * sqrt.(diag(var_cov_matrix))
right_end = parameters + 1.96 * sqrt.(diag(var_cov_matrix))
left_end["Sigma"] = exp(left_end["Sigma"])
right_end["Sigma"] = exp(right_end["Sigma"])
parameters["Sigma"] = exp(parameters["Sigma"])
# Diagnostic tests
all_parameters = vcat(parameters, restricted_parameters)
residuals = Y - X * all_parameters[["Beta 1", "Beta 2"]]
BIC = -2 * ll + log(n) * length(parameters)
BS = JarqueBeraTest(residuals; adjusted=true).JB
civec = vcat(hcat(left_end.array, right_end.array), [0 0])
CustomModel(all_parameters.array, augmented_var_cov_matrix.array,
n - length(s_v), n, names_param, "Outcome", ll, BIC, BS, civec)
endKey Points:
- Uses maximum likelihood estimation with numerical optimization
- Automatic differentiation for computing Hessian matrix
- Handles parameter transformations (e.g., log-sigma)
- Includes diagnostic testing (Jarque-Bera for normality)
- Constructs confidence intervals accounting for transformations
Step 6: Usage and Advanced Output
# Set up optimization parameters
s_v = NamedArray(Array{Float64}(undef, 2))
setnames!(s_v, ["Beta 2", "Sigma"], 1)
s_v["Beta 2"] = 1.0
s_v["Sigma"] = 0.05
# Set up restricted parameters
s_v_r = NamedArray(Array{Float64}(undef, 1))
setnames!(s_v_r, ["Beta 1"], 1)
s_v_r["Beta 1"] = 3.0
# Fit model and generate table with confidence intervals
rr1 = StatsAPI.fit(CustomModel, x, y, s_v, s_v_r)
regtable(rr1, below_statistic = ConfInt)This produces:
-------------------------------
Outcome
-------------------------------
Beta 2 3.069
(2.926, 3.212)
Sigma 0.406
(0.326, 0.506)
Beta 1 3.000
restricted
-------------------------------
N 40
Normality 1.115
Log Likelihood -20.722
BIC 48.823
-------------------------------Implementation Guidelines
1. Choose Your Approach
- Example 1 is ideal for linear models with standard statistics
- Example 2 is better for complex models requiring numerical methods
2. Essential Methods (Minimum Requirements)
Always implement these core methods:
StatsAPI.coef()- coefficient estimates- One of:
StatsAPI.coefnames()ORStatsModels.formula() - One of:
StatsAPI.responsename()ORStatsModels.formula() - One of:
StatsAPI.stderror()ORStatsAPI.vcov() - One of:
StatsAPI.islinear()ORRegressionTables.RegressionType()
3. Optional Enhancements
StatsAPI.confint()- enables confidence interval display- Custom statistics via
AbstractRegressionStatistic - Custom formatting via
Base.repr()methods - Model-specific defaults via
default_regression_statistics()
4. Testing Your Implementation
# Verify basic functionality
model = fit(YourModel, data...)
@assert length(coef(model)) == length(coefnames(model))
# Test table generation
table = regtable(model)5. Advanced Customization
- Override
default_regression_statistics()to change default statistics - Use
RegressionType()to customize model identification - Implement
other_stats()for completely custom table sections