Unit 6 - Notes
INT394
Unit 6: Model Complexity and Optimization
1. Theoretical Foundations of Generalization
Understanding how machine learning models generalize from training data to unseen data requires quantifying model complexity.
1.1 VC Dimension (Vapnik-Chervonenkis Dimension)
The VC dimension is a measure of the capacity (complexity) of a statistical classification algorithm. It quantifies the algorithm's ability to "shatter" a set of points.
- Shattering: A hypothesis space $\mathcal{H}$ shatters a set of points $S$ if, for every possible assignment of binary labels (0 or 1) to the points in $S$, there exists a hypothesis $h \in \mathcal{H}$ that classifies all of them correctly.
- Definition: The VC dimension of a hypothesis space $\mathcal{H}$, denoted $d_{VC}(\mathcal{H})$, is the size of the largest set of points that can be shattered by $\mathcal{H}$ (infinite if arbitrarily large sets can be shattered).
- Significance:
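- If $d_{VC}(\mathcal{H})$ is finite, the model will generalize given enough data: the gap between training error and true error shrinks as the sample size grows.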
- Low VC Dimension: The model is simple (e.g., a line). It may suffer from high bias (underfitting).
- High VC Dimension: The model is complex (e.g., high-degree polynomial). It requires significantly more data to train and is prone to high variance (overfitting).
- Examples:
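- Linear Classifier in 2D: Can shatter 3 points in general position (not all on one line). Cannot shatter any set of 4 points (e.g., the XOR labeling of a square's corners is not linearly separable). Hence $d_{VC} = 3$.
- Linear Classifier in d-dimensions: $d_{VC} = d + 1$ (a hyperplane with a bias term).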
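Shattering can be checked by brute force for tiny point sets: enumerate all $2^n$ labelings and test each one for linear separability. This sketch is illustrative only. It uses the perceptron algorithm (not covered in these notes) as the separability test, and `perceptron_separates` / `is_shattered` are hypothetical helper names. A `True` from the perceptron always means a genuine separating line was found; a `False` after the epoch budget is a heuristic verdict, although for the XOR and collinear labelings below it is correct because those labelings truly are not separable.

```python
from itertools import product


def perceptron_separates(points, labels, max_epochs=1000):
    """Try to find a line w.x + b that puts every +1 point on one side and every -1 point on the other.

    Returns True only if an epoch finishes with zero mistakes (a real separating line).
    Returning False after max_epochs is a heuristic 'not separable' verdict.
    """
    w1 = w2 = b = 0.0
    for _ in range(max_epochs):
        mistakes = 0
        for (x1, x2), y in zip(points, labels):
            if y * (w1 * x1 + w2 * x2 + b) <= 0:  # misclassified or exactly on the line
                w1 += y * x1
                w2 += y * x2
                b += y
                mistakes += 1
        if mistakes == 0:
            return True
    return False


def is_shattered(points):
    """A point set is shattered if every one of its 2^n labelings (+1/-1) is linearly separable."""
    return all(perceptron_separates(points, labels)
               for labels in product((-1, 1), repeat=len(points)))


print(is_shattered([(0, 0), (1, 0), (0, 1)]))          # True: 3 points in general position
print(is_shattered([(0, 0), (1, 1), (1, 0), (0, 1)]))  # False: the XOR labelings fail
print(is_shattered([(0, 0), (1, 0), (2, 0)]))          # False: collinear points
```

Note that this checks particular point sets only. $d_{VC} = 3$ needs *some* 3-point set to be shattered (first call) and *no* 4-point set to be shattered, which the single XOR example illustrates but does not prove.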
1.2 Rademacher Complexity
While VC dimension is distribution-independent (worst-case scenario), Rademacher Complexity measures the ability of a hypothesis class to fit random noise given a specific data distribution.
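- Concept: It measures how well the function class correlates with random noise labels (Rademacher variables $\sigma_i$, each independently $+1$ or $-1$ with probability $1/2$).
- Empirical Rademacher Complexity: for a sample $S = (x_1, \dots, x_m)$,
$$\hat{\mathfrak{R}}_S(\mathcal{H}) = \mathbb{E}_{\sigma}\left[\sup_{h \in \mathcal{H}} \frac{1}{m} \sum_{i=1}^{m} \sigma_i h(x_i)\right]$$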
- Interpretation:
- If a hypothesis class has high Rademacher complexity, it can fit random noise easily, which suggests it is too complex and likely to overfit.
- It can provide tighter, data-dependent generalization bounds than VC dimension because it takes the actual data distribution into account.
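The "fit random noise" idea can be made concrete with a Monte Carlo estimate of the empirical Rademacher complexity: draw random sign vectors $\sigma$, find the hypothesis that correlates best with each, and average. This is a sketch under assumptions: each hypothesis is represented only by the ±1 labels it assigns to a fixed sample $S$, the two classes compared (all $2^m$ labelings vs. 1D threshold classifiers) are chosen purely for illustration, and the function names are hypothetical.

```python
import random
from itertools import product


def threshold_labelings(xs):
    """Labelings produced on xs by 1D threshold classifiers h(x) = s * sign(x - t), s in {+1, -1}."""
    srt = sorted(xs)
    cuts = [srt[0] - 1.0] + [(a + b) / 2 for a, b in zip(srt, srt[1:])] + [srt[-1] + 1.0]
    return [[s * (1 if x > t else -1) for x in xs] for t in cuts for s in (1, -1)]


def empirical_rademacher(hypotheses, m, n_trials=500, seed=0):
    """Monte Carlo estimate of E_sigma[ sup_h (1/m) * sum_i sigma_i * h(x_i) ]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        sigma = [rng.choice((-1, 1)) for _ in range(m)]  # random noise labels
        # best correlation any hypothesis in the class achieves with this noise
        total += max(sum(s * h for s, h in zip(sigma, hyp)) / m for hyp in hypotheses)
    return total / n_trials


xs = [0.1, 0.5, 0.9, 1.3, 2.0, 2.7]  # a fixed sample S with m = 6
all_labelings = [list(lab) for lab in product((-1, 1), repeat=len(xs))]
print(empirical_rademacher(all_labelings, len(xs)))           # 1.0: this class matches any noise pattern
print(empirical_rademacher(threshold_labelings(xs), len(xs)))  # smaller: thresholds cannot fit most noise
```

The class of all labelings reaches the maximum value of 1, which is exactly the "can fit pure noise" warning sign; the much simpler threshold class scores well below that on the same sample.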
1.3 Structural Risk Minimization (SRM)
SRM is an inductive principle for model selection used to balance model complexity and training error.
- Empirical Risk Minimization (ERM): Minimizes error on the training set only.
- Problem: With a rich hypothesis space, minimizing training error alone can lead to overfitting (memorization).
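- SRM Principle: Minimizes a bound on the True Risk, which is the sum of the empirical risk and a complexity penalty (the bound holds with high probability):
$$R(h) \le R_{\text{emp}}(h) + \Omega(d_{VC}, n)$$
- Where $\Omega(d_{VC}, n)$ is a complexity term that grows with the VC dimension $d_{VC}$ and shrinks as the sample size $n$ grows.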
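- Mechanism: SRM orders hypothesis spaces by complexity ($\mathcal{H}_1 \subset \mathcal{H}_2 \subset \cdots$, so $d_{VC}(\mathcal{H}_1) \le d_{VC}(\mathcal{H}_2) \le \cdots$) and selects the hypothesis that minimizes the bound on the true risk, not just the empirical error.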
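The SRM selection rule amounts to "pick the nested class whose bound is smallest". For the complexity term, this sketch assumes one classic form of Vapnik's VC bound for 0-1 loss, which holds with probability at least $1 - \delta$: $\Omega = \sqrt{\frac{d_{VC}(\ln(2n/d_{VC}) + 1) - \ln(\delta/4)}{n}}$. The function names are hypothetical, and the empirical risks in the usage example are made-up numbers chosen only to show the trade-off, not results from any model.

```python
import math


def vc_confidence(d_vc, n, delta=0.05):
    """Complexity term Omega(d_vc, n) from one classic form of Vapnik's VC bound (0-1 loss)."""
    return math.sqrt((d_vc * (math.log(2 * n / d_vc) + 1) - math.log(delta / 4)) / n)


def srm_select(levels, n, delta=0.05):
    """levels: (name, d_vc, empirical_risk) for nested classes H_1 subset H_2 subset ...

    ERM would pick the smallest empirical risk; SRM picks the smallest R_emp + Omega.
    """
    return min(levels, key=lambda lv: lv[2] + vc_confidence(lv[1], n, delta))


# Hypothetical numbers: training error keeps falling as the classes get richer.
levels = [("H1", 3, 0.20), ("H2", 10, 0.12), ("H3", 50, 0.10), ("H4", 200, 0.08)]
n = 10_000
for name, d, emp in levels:
    print(name, "bound:", round(emp + vc_confidence(d, n), 3))
print("SRM picks:", srm_select(levels, n)[0])
```

With these inputs ERM would choose H4 (lowest training error), while SRM chooses H2, because the penalty for $d_{VC} = 50$ or $200$ outweighs the small gain in empirical risk.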
2. Error Analysis: Bias, Variance, and Fitting
2.1 Bias-Variance Trade-off
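The expected squared error of a model at a point $x$ can be decomposed into three parts: Bias, Variance, and Irreducible Error, where $\sigma^2$ is the variance of the noise in $y$:
$$\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \text{Bias}\big[\hat{f}(x)\big]^2 + \text{Var}\big[\hat{f}(x)\big] + \sigma^2$$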
- Bias (Error from erroneous assumptions):
- The difference between the average prediction of the model and the correct value.
- High Bias: The model is too simple to capture the underlying pattern (e.g., linear regression on quadratic data).
- Variance (Error from sensitivity to fluctuations):
- The variability of model prediction for a given data point if the model is retrained on different realizations of the data.
- High Variance: The model captures random noise in the training data rather than the underlying pattern.
- Irreducible Error ($\sigma^2$): Noise inherent in the problem itself; no model can remove it.
The Trade-off: Increasing model complexity generally decreases bias but increases variance. The optimal model minimizes the sum of bias squared and variance.
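The decomposition can be estimated by simulation: retrain the same model on many independent training sets, then measure how far its average prediction at one point is from the truth (bias) and how much the predictions spread (variance). This sketch assumes a synthetic target $f(x) = \sin(2\pi x)$ with Gaussian noise and uses k-nearest-neighbour regression as the model, because a single knob controls its complexity (small $k$ is flexible, large $k$ is rigid). All of these choices and the helper names are illustrative.

```python
import math
import random


def true_f(x):
    return math.sin(2 * math.pi * x)


def knn_predict(train, x0, k):
    """Average the targets of the k training points closest to x0."""
    nearest = sorted(train, key=lambda p: abs(p[0] - x0))[:k]
    return sum(y for _, y in nearest) / k


def bias_variance_at(x0, k, n_train=30, n_repeats=500, noise_sd=0.3, seed=0):
    rng = random.Random(seed)
    preds = []
    for _ in range(n_repeats):
        # a fresh training set each time: a new "realization of the data"
        xs = [rng.random() for _ in range(n_train)]
        train = [(x, true_f(x) + rng.gauss(0, noise_sd)) for x in xs]
        preds.append(knn_predict(train, x0, k))
    mean_pred = sum(preds) / len(preds)
    bias_sq = (mean_pred - true_f(x0)) ** 2
    variance = sum((p - mean_pred) ** 2 for p in preds) / len(preds)
    return bias_sq, variance


for k in (1, 5, 30):
    b2, var = bias_variance_at(0.25, k)
    print(f"k={k:2d}  bias^2={b2:.4f}  variance={var:.4f}")
```

With $k = 1$ the bias term is tiny but the variance is at least the noise variance; with $k$ equal to the training-set size the model always predicts the overall mean, so the variance collapses and the bias dominates.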
2.2 Overfitting vs. Underfitting
- Underfitting (High Bias):
- Symptoms: High training error and high validation error.
- Cause: Model is too simple; features are insufficient.
- Solution: Increase model complexity, add new features, reduce regularization.
- Overfitting (High Variance):
- Symptoms: Low training error but high validation error.
- Cause: Model is too complex; insufficient training data.
- Solution: Simplify model, gather more data, use regularization, early stopping.
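The symptoms listed above can be reproduced on synthetic data by comparing training and validation error as complexity changes. This sketch assumes a toy setup (noisy $\sin(2\pi x)$ targets, k-NN regression, hypothetical helper names): $k = 1$ memorizes the training set, while $k$ equal to the training-set size always predicts the same constant.

```python
import math
import random


def make_data(n, rng, noise_sd=0.3):
    xs = [rng.random() for _ in range(n)]
    return [(x, math.sin(2 * math.pi * x) + rng.gauss(0, noise_sd)) for x in xs]


def knn_predict(train, x0, k):
    """Average the targets of the k training points closest to x0."""
    nearest = sorted(train, key=lambda p: abs(p[0] - x0))[:k]
    return sum(y for _, y in nearest) / k


def mse(train, data, k):
    """Mean squared error on `data` of a k-NN model built from `train`."""
    return sum((y - knn_predict(train, x, k)) ** 2 for x, y in data) / len(data)


rng = random.Random(42)
train, val = make_data(60, rng), make_data(60, rng)
for k in (1, 7, 60):
    print(f"k={k:2d}  train MSE={mse(train, train, k):.3f}  val MSE={mse(train, val, k):.3f}")
```

Expect $k = 1$ to show zero training error with a clearly higher validation error (overfitting), $k = 60$ to show high error on both sets (underfitting), and the intermediate $k$ to typically do best on the validation data.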
3. Practical Model Improvement
3.1 Regularization Techniques
Regularization adds a penalty term to the loss function to constrain the magnitude of the model parameters (weights), effectively reducing complexity.
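Loss Function with Regularization:
$$J(\mathbf{w}) = \text{Loss}(\mathbf{w}) + \lambda\, R(\mathbf{w})$$
(Where $\lambda \ge 0$ is the regularization strength hyperparameter and $R(\mathbf{w})$ is the penalty term)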
- L1 Regularization (Lasso Regression):
- Penalty: Sum of absolute values of weights ($\lambda \sum_j |w_j|$).
- Effect: Can shrink coefficients to exactly zero, so it performs Feature Selection (sparse solutions).
- L2 Regularization (Ridge Regression):
- Penalty: Sum of squared weights ($\lambda \sum_j w_j^2$).
- Effect: Shrinks coefficients toward zero but rarely makes them exactly zero. Handles collinearity well: for $\lambda > 0$, the matrix $X^\top X + \lambda I$ in the closed-form solution is always invertible.
- Elastic Net:
- Combination of L1 and L2 penalties ($\lambda_1 \sum_j |w_j| + \lambda_2 \sum_j w_j^2$). Useful when features are correlated (L2 handles grouping, L1 handles sparsity).
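The different behaviour of the L1 and L2 penalties shows up clearly in coordinate descent, where the L1 term turns into a soft-thresholding step that can set a weight to exactly zero. This sketch assumes the objective $\frac{1}{2n}\|y - Xw\|^2 + \lambda\left(\rho\|w\|_1 + \frac{1-\rho}{2}\|w\|_2^2\right)$, which is one common elastic-net scaling ($\rho$ is `l1_ratio`: $\rho = 1$ gives Lasso, $\rho = 0$ gives Ridge). The dataset is a tiny hand-made example with orthogonal features, chosen so the answers can be checked by hand; the helper names are hypothetical.

```python
def soft_threshold(a, t):
    """S(a, t) = sign(a) * max(|a| - t, 0): the L1 step that zeroes small weights."""
    if a > t:
        return a - t
    if a < -t:
        return a + t
    return 0.0


def elastic_net_cd(X, y, lam, l1_ratio, n_iter=100):
    """Cyclic coordinate descent for the elastic-net objective described above."""
    n, p = len(X), len(X[0])
    w = [0.0] * p
    for _ in range(n_iter):
        for j in range(p):
            # correlation of feature j with the residual that excludes feature j
            rho = sum(X[i][j] * (y[i] - sum(X[i][k] * w[k] for k in range(p) if k != j))
                      for i in range(n)) / n
            z = sum(X[i][j] ** 2 for i in range(n)) / n
            w[j] = soft_threshold(rho, lam * l1_ratio) / (z + lam * (1 - l1_ratio))
    return w


# Orthogonal +/-1 features; the true weights are 3, 0.2 and 0.02.
X = [[1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1]]
y = [3 * r[0] + 0.2 * r[1] + 0.02 * r[2] for r in X]
print("Lasso:", [round(w, 4) for w in elastic_net_cd(X, y, 0.1, 1.0)])  # [2.9, 0.1, 0.0]
print("Ridge:", [round(w, 4) for w in elastic_net_cd(X, y, 0.1, 0.0)])  # [2.7273, 0.1818, 0.0182]
print("ENet :", [round(w, 4) for w in elastic_net_cd(X, y, 0.1, 0.5)])  # [2.8095, 0.1429, 0.0]
```

With orthogonal features, Lasso subtracts $\lambda$ from each least-squares coefficient and clips at zero (so the small third weight disappears), while Ridge divides every coefficient by $1 + \lambda$ and keeps all of them nonzero.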
3.2 Cross-Validation
A resampling procedure used to evaluate machine learning models on a limited data sample.
- k-Fold Cross-Validation:
- Split data into $k$ equal subsets (folds).
- Train on $k-1$ folds, validate on the remaining 1 fold.
- Repeat $k$ times (rotating the validation fold).
- Result: Average of the $k$ scores. Common choices such as $k = 5$ or $k = 10$ balance the bias and variance of the error estimate.
- Leave-One-Out Cross-Validation (LOOCV):
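- $k = n$ (total number of data points).
- Train on all points except one; validate on the held-out point.
- Pros: Approximately unbiased estimate of test error (each model is trained on $n-1$ points).
- Cons: Computationally expensive ($n$ model fits); the estimate can have high variance.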
- Stratified k-Fold:
- Ensures each fold represents the same class proportions as the whole dataset. Essential for imbalanced datasets.
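The fold bookkeeping is easy to get wrong, so here is a minimal pure-Python sketch of plain and stratified k-fold index splitting. The helper names are hypothetical; libraries such as scikit-learn provide tested equivalents (`KFold`, `StratifiedKFold`). LOOCV is the special case $k = n$.

```python
import random
from collections import defaultdict


def k_fold_indices(n, k, seed=0):
    """Shuffle 0..n-1 and deal them into k folds whose sizes differ by at most one."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def stratified_k_fold_indices(labels, k, seed=0):
    """Deal each class's shuffled indices round-robin into k folds, preserving class proportions."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    folds = [[] for _ in range(k)]
    for members in by_class.values():
        rng.shuffle(members)
        for j, i in enumerate(members):
            folds[j % k].append(i)
    return folds


def cross_val_splits(folds):
    """Yield (train, validation) index lists; each fold is the validation set exactly once."""
    for v, val in enumerate(folds):
        train = [i for f, fold in enumerate(folds) if f != v for i in fold]
        yield train, val


labels = [0] * 90 + [1] * 10  # imbalanced: 10% positives
for fold in stratified_k_fold_indices(labels, 5):
    print(len(fold), "points,", sum(labels[i] for i in fold), "positives")
```

Every stratified fold gets 20 points with exactly 2 positives, matching the 10% rate of the whole dataset; a plain shuffled split could easily leave a fold with no positives at all.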
3.3 Hyperparameter Tuning
Hyperparameters are configuration variables external to the model (e.g., learning rate, regularization strength $\lambda$, $k$ in k-NN) that must be set before training.
- Grid Search:
- Define a grid of hyperparameter values.
- Exhaustively search through every combination.
- Con: The number of combinations grows exponentially with the number of hyperparameters, so it becomes computationally prohibitive in high dimensions.
- Random Search:
- Sample hyperparameter combinations from statistical distributions.
- Pro: Often finds better models than grid search in less time because some hyperparameters matter more than others.
- Bayesian Optimization:
- Builds a probabilistic model of the function mapping hyperparameters to a target objective (e.g., validation accuracy).
- Intelligently selects the next hyperparameter set to evaluate based on past results (Exploration vs. Exploitation).
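The difference between grid and random search is easiest to see with the same evaluation budget. In this sketch, `validation_score` is a hypothetical stand-in for "train the model and score it on validation data", built so that the learning rate matters much more than the regularization strength; the search ranges are also illustrative. Bayesian optimization needs a surrogate model (commonly a Gaussian process) and is not shown.

```python
import itertools
import math
import random


def validation_score(lr, reg):
    """HYPOTHETICAL objective: peaks at lr = 10**-2.5; reg barely matters."""
    return -(math.log10(lr) + 2.5) ** 2 - 0.01 * (reg - 0.5) ** 2


def grid_search(lr_grid, reg_grid):
    candidates = list(itertools.product(lr_grid, reg_grid))  # every combination
    best = max(candidates, key=lambda c: validation_score(*c))
    return best, validation_score(*best), len(candidates)


def random_search(n_samples, seed=0):
    rng = random.Random(seed)
    # lr sampled log-uniformly in [1e-4, 1e-1], reg uniformly in [0, 1]
    candidates = [(10 ** rng.uniform(-4, -1), rng.uniform(0, 1)) for _ in range(n_samples)]
    best = max(candidates, key=lambda c: validation_score(*c))
    return best, validation_score(*best), len(candidates)


print("grid  :", grid_search([1e-4, 1e-3, 1e-2, 1e-1], [0.0, 0.25, 0.5, 0.75]))
print("random:", random_search(16))
```

Both searches spend 16 evaluations, but the grid only ever tries 4 distinct learning rates, while random search tries 16. When only a few hyperparameters really matter, that extra coverage of the important axis is why random search tends to do better.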
4. Gradient Descent and Variants
Gradient Descent (GD) is an iterative optimization algorithm used to minimize the cost function $J(\theta)$.
4.1 Basic Gradient Descent
Update Rule:
$$\theta := \theta - \alpha \nabla_\theta J(\theta)$$
- $\alpha$: Learning rate (step size).
- $\nabla_\theta J(\theta)$: Gradient (direction of steepest ascent). We subtract it to go downhill.
- Batch Gradient Descent: Uses the entire dataset to compute the gradient. Stable but slow for large datasets.
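- Stochastic Gradient Descent (SGD): Uses a single training example for each update. Each update is cheap, but the gradient estimate is noisy, so the loss fluctuates (oscillates) instead of decreasing smoothly.
- Mini-Batch GD: Uses a small batch (e.g., 32 or 64 samples) per update. Usually the best balance of stability and speed, and it maps well onto vectorized hardware.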
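All three variants use the same update rule and differ only in how many examples feed each gradient estimate. This sketch assumes a one-feature linear regression with mean squared error on synthetic data (true slope 2 and intercept 1 are arbitrary choices). Setting `batch_size` to the dataset size gives Batch GD, 1 gives SGD, and anything in between gives Mini-Batch GD.

```python
import random


def gradient(w, b, batch):
    """Gradient of the MSE (1/m) * sum((w*x + b - y)^2) over the batch."""
    m = len(batch)
    dw = sum(2 * (w * x + b - y) * x for x, y in batch) / m
    db = sum(2 * (w * x + b - y) for x, y in batch) / m
    return dw, db


def train(data, batch_size, lr=0.05, epochs=200, seed=0):
    rng = random.Random(seed)
    data = list(data)
    w, b = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(data)  # visit examples in a new order each epoch
        for start in range(0, len(data), batch_size):
            dw, db = gradient(w, b, data[start:start + batch_size])
            w -= lr * dw   # theta := theta - alpha * grad J(theta)
            b -= lr * db
    return w, b


rng = random.Random(1)
xs = [rng.uniform(-1, 1) for _ in range(256)]
data = [(x, 2 * x + 1 + rng.gauss(0, 0.1)) for x in xs]
for name, bs in (("Batch", len(data)), ("SGD", 1), ("Mini-batch", 32)):
    w, b = train(data, bs)
    print(f"{name:10s} w={w:.3f} b={b:.3f}")
```

All three recover roughly $w \approx 2$, $b \approx 1$ here. Batch GD makes one smooth step per epoch, SGD makes 256 noisy ones, and mini-batch sits in between.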
4.2 Advanced Variants
Momentum
Addresses the problem where SGD oscillates across the slopes of a ravine while only making small progress along the bottom.
- Mechanism: Adds a fraction of the previous update vector to the current update. It gains speed in directions with persistent gradients.
- Formula:
$$v_t = \gamma v_{t-1} + \alpha \nabla_\theta J(\theta)$$
$$\theta = \theta - v_t$$
(Where $\gamma$ is the momentum term, typically 0.9)
RMSprop (Root Mean Square Propagation)
An adaptive learning rate method designed to resolve Adagrad's radically diminishing learning rates.
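- Mechanism: Divides the gradient by the square root of an exponentially decaying average of squared gradients, which gives each parameter its own effective learning rate: parameters with consistently large gradients take smaller steps, while parameters with small or infrequent gradients (e.g., sparse features) take relatively larger ones.
- Formula:
- Compute squared gradient average: $E[g^2]_t = \beta E[g^2]_{t-1} + (1 - \beta)\, g_t^2$
- Update parameter: $\theta_{t+1} = \theta_t - \dfrac{\alpha}{\sqrt{E[g^2]_t + \epsilon}}\, g_t$
(Where $g_t = \nabla_\theta J(\theta_t)$, $\beta$ is the decay rate, commonly 0.9, and $\epsilon$ is a small constant that prevents division by zero)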
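The ravine problem and the two fixes can be compared on a hypothetical ill-conditioned quadratic, $f(x, y) = \tfrac{1}{2}(x^2 + 50y^2)$, whose curvature differs by a factor of 50 between the two directions. The sketch implements the standard momentum and RMSprop updates; the learning rates, $\gamma = 0.9$, $\beta = 0.9$, the starting point and the 200-step budget are illustrative assumptions, not tuned values.

```python
import math


def grad(p):
    """Gradient of the hypothetical ravine f(x, y) = 0.5 * (x**2 + 50 * y**2)."""
    x, y = p
    return [x, 50.0 * y]


def gd(p, lr=0.02, steps=200):
    p = list(p)
    for _ in range(steps):
        g = grad(p)
        p = [pi - lr * gi for pi, gi in zip(p, g)]
    return p


def momentum(p, lr=0.02, gamma=0.9, steps=200):
    p, v = list(p), [0.0, 0.0]
    for _ in range(steps):
        g = grad(p)
        v = [gamma * vi + lr * gi for vi, gi in zip(v, g)]  # v_t = gamma * v_{t-1} + alpha * grad
        p = [pi - vi for pi, vi in zip(p, v)]                # theta = theta - v_t
    return p


def rmsprop(p, lr=0.01, beta=0.9, eps=1e-8, steps=200):
    p, sq = list(p), [0.0, 0.0]
    for _ in range(steps):
        g = grad(p)
        sq = [beta * s + (1 - beta) * gi * gi for s, gi in zip(sq, g)]  # E[g^2]_t per parameter
        p = [pi - lr * gi / math.sqrt(s + eps) for pi, gi, s in zip(p, g, sq)]
    return p


for name, opt in (("GD", gd), ("Momentum", momentum), ("RMSprop", rmsprop)):
    x, y = opt([1.0, 1.0])
    print(f"{name:8s} distance to minimum: {math.hypot(x, y):.2e}")
```

Plain GD must keep $\alpha$ below $2/50$ to stay stable along the steep $y$ direction, so it crawls along the flat $x$ direction. Momentum builds up velocity along $x$ and ends far closer to the minimum. RMSprop rescales each coordinate by its own gradient history, so both coordinates move at a similar pace; with a fixed learning rate it hovers within a small distance of the minimum rather than converging exactly.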
4.3 Convergence Analysis
- Convexity:
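- If the Loss function is Convex (bowl-shaped), every local minimum is a global minimum, so Gradient Descent converges to the Global Minimum (assuming a minimizer exists and $\alpha$ is small enough, e.g. $\alpha \le 1/L$ when the gradient is $L$-Lipschitz).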
- If Non-Convex (many peaks and valleys), it may get stuck in a Local Minimum or Saddle Point.
- Learning Rate ($\alpha$):
- Too Small: Convergence is extremely slow.
- Too Large: The algorithm may overshoot the minimum and diverge (Loss increases).
- Lipschitz Continuity:
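- Convergence proofs often rely on the gradient being $L$-Lipschitz continuous, $\|\nabla J(\theta_1) - \nabla J(\theta_2)\| \le L \|\theta_1 - \theta_2\|$, meaning the gradient cannot change arbitrarily fast.
- Standard convergence rates (with a suitable fixed step size such as $\alpha = 1/L$): for Convex functions, GD achieves $J(\theta_k) - J(\theta^*) = O(1/k)$ after $k$ iterations; for Strongly Convex functions, the error shrinks as $O(c^k)$ for some $0 < c < 1$ (geometric, also called linear, convergence).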
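The learning-rate regimes can be seen exactly on the one-dimensional function $J(\theta) = \theta^2$, a hypothetical example that is strongly convex with an $L$-Lipschitz gradient for $L = 2$. Each GD step multiplies $\theta$ by $(1 - 2\alpha)$, so $\theta_k = (1 - 2\alpha)^k \theta_0$: geometric convergence when $0 < \alpha < 1$ (i.e. $\alpha < 2/L$), divergence when $\alpha > 1$.

```python
def gd_on_quadratic(alpha, steps=50, theta0=1.0):
    """Run gradient descent on J(theta) = theta**2, whose gradient is 2 * theta."""
    theta = theta0
    for _ in range(steps):
        theta -= alpha * 2 * theta  # theta_{k+1} = (1 - 2 * alpha) * theta_k
    return theta


for alpha in (0.01, 0.4, 1.1):
    print(f"alpha={alpha}: theta after 50 steps = {gd_on_quadratic(alpha):.3g}")
```

After 50 steps, $\alpha = 0.01$ has only reached about $0.98^{50} \approx 0.36$ (too small), $\alpha = 0.4$ is essentially at the minimum ($0.2^{50}$), and $\alpha = 1.1$ has blown up to roughly $1.2^{50} \approx 9 \times 10^3$ (too large).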