Quantile Loss in Neural Networks

Implementation, results tendency, and handling

Shiro Matsumoto
12 min read · Dec 29, 2023

The usefulness of quantile loss in demand forecasting

When building a prediction model using machine learning in business, there are situations where you do not simply want to know “the predicted value (point prediction)” but rather “how likely the predicted value is to fall within a certain range (interval prediction).” The best example of this is demand forecasting. When demand forecasting is required, business needs cannot be met simply by forecasting the quantity most likely to sell. It is necessary to predict how much will sell when demand swings upwards and prepare inventory to meet that demand. For example, if you stock only the most likely amount of the demand forecast, you will be out of stock (roughly) one out of every two times and miss the opportunity to make a sale. If you have an inventory at the 95th percentile of the forecast (the value at which there is a 95% probability that demand will be less than or equal to this value), you can reduce the number of shortages to approximately once in 20 times.

Machine learning methods for obtaining such percentile values include quantile regression and various GBDT libraries.

Quantile Regression:

GBDT libraries:

This article explains how to apply quantile loss on neural networks by defining it as a custom loss function and documenting its use.

Why does a quantile loss provide a quantile?

Before getting to the main topic, here is a simple illustration of why minimizing quantile loss yields a quantile. Readers who already understand this may skip this section. Rather than giving a rigorous proof, the discussion is intuitive and restricted to a few examples. The least squares method finds the value a that, for the data x₁, …, xₙ, minimizes the squared loss function Σ(xᵢ − a)². In contrast, the quantile loss is built on the check function

ρ_τ(u) = τ·u if u ≥ 0, and (τ − 1)·u if u < 0,

and the goal is to find the value a that minimizes the quantile loss Σ ρ_τ(xᵢ − a).

First, I show that the value a that minimizes quantile loss when τ = 0.5 is the median of x₁, …, xₙ. From the definition, the quantile loss for τ = 0.5 is

0.5 · Σ |xᵢ − a|.

The absolute value function cannot be differentiated at the point where it is 0, but if we ignore the possibility that |xᵢ − a| = 0, it can be differentiated elsewhere by splitting into cases:

d|xᵢ − a|/da = −1 if xᵢ > a, and +1 if xᵢ < a.

From this, differentiating the quantile loss with respect to a gives

0.5 · [(the number of xᵢ with xᵢ < a) − (the number of xᵢ with xᵢ > a)].

Since the quantile loss is minimized where this derivative is zero, the point at which “the number of xᵢ with xᵢ<a” equals “the number of xᵢ with xᵢ>a,” i.e., the median, is the minimum point of the quantile loss.

Next, let us look at the case where τ = 0.25 in the same way. From the definition, the quantile loss is

0.25 · Σ_{xᵢ > a} (xᵢ − a) + 0.75 · Σ_{xᵢ < a} (a − xᵢ).

Similarly, differentiating case by case with respect to a gives

0.75 · (the number of xᵢ with xᵢ < a) − 0.25 · (the number of xᵢ with xᵢ > a).

The derivative of the quantile loss is zero when the value of a satisfies “the number of xᵢ with xᵢ<a” : “the number of xᵢ with xᵢ>a” = 1 : 3. This means that a is the first quartile point (25th percentile point) of x₁, …, xₙ.
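The argument above can be checked numerically. The following is a minimal sketch, not part of the original derivation: it evaluates the quantile loss at every sample value (the helper name best_constant is hypothetical) and compares the minimizer with torch.quantile. A sample size of 1001 is assumed so that the median and the first quartile coincide exactly with sample points.

```python
import torch

def best_constant(x, tau):
    """Return the sample value a that minimizes sum_i rho_tau(x_i - a)."""
    # u[j, i] = x_i - a_j, where the candidates a_j are the samples themselves
    u = x.unsqueeze(0) - x.unsqueeze(1)
    # Check function written as a max, summed over the data for each candidate
    losses = torch.max(tau * u, (tau - 1) * u).sum(dim=1)
    return x[losses.argmin()]

torch.manual_seed(0)
x = torch.randn(1001, dtype=torch.float64)

a_median = best_constant(x, 0.5)
a_q25 = best_constant(x, 0.25)

print(a_median.item(), torch.quantile(x, 0.5).item())
print(a_q25.item(), torch.quantile(x, 0.25).item())
```

Each printed pair should agree, which matches the counting argument: the loss stops decreasing exactly where the ratio of points below and above a reaches τ : (1 − τ).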

Implementation of quantile loss with PyTorch

Here is an example of defining quantile loss as a custom loss function using PyTorch.

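import torch

# Define quantile loss function
def quantile_loss(preds, target, quantile):
    assert 0 < quantile < 1, "Quantile should be in (0, 1) range"
    # A positive error means the prediction is below the target
    errors = target - preds
    # Check function: quantile * e if e >= 0, (quantile - 1) * e otherwise
    loss = torch.max((quantile - 1) * errors, quantile * errors)
    # The check function is already non-negative, so abs() does not change the value
    return torch.abs(loss).mean()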

An example of training with this custom loss function is shown below.

# Train model (model, optimizer, dataloader, num_epochs, and quantile are defined elsewhere)
for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = quantile_loss(outputs, batch_y, quantile)
        loss.backward()
        optimizer.step()
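For readers who want to run the loop end to end, here is a minimal self-contained sketch. The toy data, the network built with nn.Sequential, the learning rate, the batch size of 32, and the 0.9 quantile are all assumptions chosen for illustration, not the settings used in the experiments of this article.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def quantile_loss(preds, target, quantile):
    errors = target - preds
    return torch.max((quantile - 1) * errors, quantile * errors).mean()

torch.manual_seed(0)

# Hypothetical toy data: the noise scale grows with x
x = torch.rand(2000, 1) * 2 - 1
y = torch.randn(2000, 1) * torch.exp(x)

# Assumed setup: small MLP, Adam optimizer, shuffled mini-batches
model = nn.Sequential(
    nn.Linear(1, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
dataloader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
quantile = 0.9
num_epochs = 20

for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        loss = quantile_loss(model(batch_x), batch_y, quantile)
        loss.backward()
        optimizer.step()

# Fraction of training targets at or below the prediction
with torch.no_grad():
    coverage = (y <= model(x)).float().mean().item()
print(f"target quantile: {quantile}, empirical coverage: {coverage:.3f}")
```

If training has worked, the printed coverage should land roughly near the target quantile, although the exact value depends on the seed and the settings.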

Let’s see if this defined custom loss function works as expected.

Execution of quantile loss with PyTorch

First, generate x from a uniform distribution on (−5, 5) and y from a zero-mean normal distribution whose standard deviation is proportional to exp(x/5), and see whether the quantile points of y can be predicted from x.

# Generate dummy data
num_samples = 10000
shape = (num_samples, 1)
torch.manual_seed(0)

# x is uniform random from -5 to 5
# y is random normal distribution * exp(scaled x)
x_tensor = torch.rand(shape) * 10 - 5
x_scaled = x_tensor / 5
y_tensor = torch.randn(shape) * torch.exp(x_scaled)

# Convert values to NumPy array (for graphs)
x = x_tensor.numpy()
y = y_tensor.numpy()

The network structure is simple: two intermediate layers of 64 nodes each, with a ReLU after each layer. Training runs for 100 epochs without any regularization or early stopping. The quantiles (percentile values) to be predicted are [0.500, 0.700, 0.950, 0.990, 0.995], shown in columns, and the batch sizes are [1, 4, 16, 64, 256], shown in rows, for a total of 25 predictions. The ratio of the 10,000 training instances (blue) that fall below the predicted output value (red) is noted in each figure as the “Actual” value.
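The “Actual” value can be computed in a couple of lines. The sketch below uses a hypothetical helper, empirical_coverage, and a constant prediction placed at 1.645 (approximately the 95th percentile of a standard normal distribution) purely as an illustration; in the experiments the predictions come from the trained network instead.

```python
import torch

def empirical_coverage(preds, target):
    """Fraction of targets that are less than or equal to the prediction."""
    return (target <= preds).float().mean().item()

torch.manual_seed(0)
y = torch.randn(100_000, 1)

# Constant prediction near the true 95th percentile of N(0, 1)
preds = torch.full_like(y, 1.645)
coverage = empirical_coverage(preds, y)
print(f"Actual: {coverage:.3f}")
```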

Image by author

The percentage of samples below the specified percentile value is generally close to the specified value, and the shape of the output quantile predictions is straightforward. If we can get results like this, we will likely get what we want.

Next, consider a slightly more complex example where y = clip(x, −2, 2) + randn. Here, clip(x, −2, 2) is the clip function, which restricts a value to a specified range: when a number is outside the given range, the function “clips” it to the nearest boundary. With the range set to −2 to 2, an input value of −5 returns −2 and an input value of 10 returns 2. randn is a random number that follows a normal distribution. The network structure and other settings are the same as in the previous case.
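The data-generation code for this case is not shown in the article, so the following is only a sketch of how it could look, assuming the same x range and sample count as before and using torch.clamp as the clip function.

```python
import torch

torch.manual_seed(0)
num_samples = 10000

# clip in PyTorch is torch.clamp: values outside [-2, 2] snap to the nearest boundary
clipped = torch.clamp(torch.tensor([-5.0, 0.5, 10.0]), min=-2.0, max=2.0)
print(clipped)  # tensor([-2.0000,  0.5000,  2.0000])

# Assumed data generation: x uniform on (-5, 5), y = clip(x, -2, 2) + standard normal noise
x_tensor = torch.rand(num_samples, 1) * 10 - 5
y_tensor = torch.clamp(x_tensor, min=-2.0, max=2.0) + torch.randn(num_samples, 1)
```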

Image by author

As in the previous case, the percentages of samples below the specified percentile value are generally close to the given value. However, toward the lower right of the 5x5 figure, that is, the closer the given percentile value is to 1 and the larger the batch size, the more the shape of the quantile predictions deviates from the shape of the training sample. The desirable shape of the quantile predictions is always that of the red line in the upper left figure, shifted upward in parallel as the specified percentile increases. However, as we move to the lower right of the figure, the red line of predictions takes on a more linear shape, which is not a preferable result. This tendency to produce flat forecasts that do not follow demand fluctuations is hereafter referred to as flatten, flattened, or flattening. It should be noted that even in the bottom right figure, the percentage of samples below the specified percentile value is 0.996, which is close to the given value of 0.995.

Let’s check this with a more complex shape for a better understanding. Here, the target is y = 2sin(x) + randn. Other settings are identical to the preceding case.

Image by author

Again, the percentages of samples below the specified percentile values are generally close to the specified values. However, the shape of the quantile predictions deviates from the sinusoidal shape as one moves toward the lower right of the 5x5 figure. The red line of the predicted values becomes more linear in the lower right of the figure.

Remark

If the above demand fluctuations were weekly cycles, and if the product procurement cycle were also weekly, the flattening in the above figure would cause no problem at first glance. However, when the input is high-dimensional data of the kind used in practice rather than toy data, it is difficult to predict in advance the nature of the irreducible errors (errors that occur even when the model is sufficiently generalized), so this flattening tendency is treated as undesirable.

Detection, avoidance, and mitigation methods

A summary of what we have seen so far is

  • In neural networks, Quantile Loss can be defined as a custom loss function, which can be trained to minimize Quantile Loss.
  • If the set percentile values are close to 0 or 1, the training results do not follow the trend of the training data and are relatively flat.
  • This “flat training result” is noticeable when the batch size is large and is not seen when the batch size is small.

In these toy data, input x was one-dimensional, and the relationship between input x and output y was clear in advance, so it was possible to determine whether the results obtained were flattened or not. Then, when the input x is high-dimensional, and the relationship between input x and output y is unknown, how can we determine whether the results obtained using Quantile Loss are “flattened” or not, and how can we “avoid flattening”?

Detection Methods for “Flattening”

One approach to detecting “flattening” is to calculate the percentile values corresponding to the median and to 1σ, 2σ, and 3σ above it together and check the relationship between these values, even if the final value to be obtained is the 99.5th percentile value. If the sample distribution follows a normal distribution with mean μ and standard deviation σ, the percentile values corresponding to μ, μ + 1σ, μ + 2σ, and μ + 3σ are the 50.00, 84.13, 97.72, and 99.87 percentile values, respectively.

Image by author
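These percentages follow from the cumulative distribution function of the standard normal distribution, and they can be reproduced with torch.distributions.Normal. This is a small sketch for checking the numbers, not code from the original experiments.

```python
import torch

# Standard normal distribution (mean 0, standard deviation 1)
standard_normal = torch.distributions.Normal(loc=0.0, scale=1.0)

# Offsets from the mean in units of sigma: the median, then +1, +2, +3 sigma
sigmas = torch.tensor([0.0, 1.0, 2.0, 3.0])
percentiles = standard_normal.cdf(sigmas) * 100

for s, p in zip(sigmas.tolist(), percentiles.tolist()):
    print(f"mu + {s:.0f} sigma -> {p:.2f} percentile")
```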

Using this, if the ratios of the (84.13 percentile value - 50.00 percentile value), (97.72 percentile value - 50.00 percentile value), and (99.87 percentile value - 50.00 percentile value) deviate significantly from 1:2:3, we can determine that the deviating percentile values have flattened.
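The ratio check can be sketched as follows. The function names and the tolerance of 25% are hypothetical choices made for illustration; what counts as a significant deviation from 1:2:3 has to be decided per application. The sketch also assumes the 84.13 percentile estimate differs from the median estimate, since that difference is used as the denominator.

```python
import torch

def spread_ratios(preds):
    """preds: (N, 4) tensor with columns for the 50.00, 84.13, 97.72, 99.87 percentiles.

    Returns an (N, 2) tensor holding (p2 - p0) / (p1 - p0) and (p3 - p0) / (p1 - p0).
    For normally distributed errors these should be close to 2 and 3.
    """
    base = preds[:, 1] - preds[:, 0]
    r2 = (preds[:, 2] - preds[:, 0]) / base
    r3 = (preds[:, 3] - preds[:, 0]) / base
    return torch.stack([r2, r3], dim=1)

def is_flattened(preds, tol=0.25):
    """Flag rows whose ratios deviate from (2, 3) by more than tol (relative)."""
    expected = preds.new_tensor([2.0, 3.0])
    deviation = (spread_ratios(preds) - expected).abs()
    return (deviation > tol * expected).any(dim=1)

# Two hand-made rows: the first follows 1:2:3, the second has a compressed 3-sigma estimate
example = torch.tensor([[0.0, 1.0, 2.0, 3.0],
                        [0.0, 1.0, 1.9, 2.0]])
flags = is_flattened(example)
print(flags)  # tensor([False,  True])
```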

Avoidance methods for “Flattening”

The first avoidance approach is to reduce the batch size, as seen in the experiment above. A large batch size can produce results in which the ratio of samples below and above the prediction within one batch balances to the specified percentile value, even with flat predictions. Smaller batch sizes avoid this problem and are less likely to produce flat predictions. On the other hand, reducing the batch size has disadvantages, such as unstable convergence and increased training time, which make this option difficult to adopt.
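One way to see the balancing effect is to look at the gradient of the batch loss with respect to a flat prediction. The sketch below is an illustration under simplifying assumptions (a single constant c stands in for a flat model output, and τ = 0.9 is arbitrary): for a batch, the gradient equals the fraction of samples below c minus τ, so it nearly vanishes once the batch is balanced, whereas for a single sample it is always either 1 − τ or −τ.

```python
import torch

def quantile_loss(preds, target, quantile):
    errors = target - preds
    return torch.max((quantile - 1) * errors, quantile * errors).mean()

torch.manual_seed(0)
tau = 0.9
y = torch.randn(256, 1)

# Flat prediction: one constant for every sample, placed at the batch's own tau-quantile
c = torch.quantile(y, tau).detach().clone().requires_grad_(True)
quantile_loss(c, y, tau).backward()
frac_below = (y < c).float().mean().item()
grad_large_batch = c.grad.item()  # equals frac_below - tau

# Same constant, but evaluated on a batch of one sample
c1 = c.detach().clone().requires_grad_(True)
quantile_loss(c1, y[:1], tau).backward()
grad_single = c1.grad.item()  # either 1 - tau or -tau

print(f"batch of 256: grad = {grad_large_batch:.4f} (frac_below - tau = {frac_below - tau:.4f})")
print(f"batch of 1:   grad = {grad_single:.4f}")
```

With a large batch, a flat prediction that already balances the batch receives almost no corrective signal, which is consistent with the flattening seen at large batch sizes.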

The second avoidance approach is to collect similar samples in the same batch instead of generating batches randomly. This avoids “balancing the ratio of samples below and above the predicted value within a batch to a specified percentile value.” However, while it is easy to collect similar samples in this toy example because x is one-dimensional, when x is higher-dimensional it is computationally expensive to define appropriately what a similar sample is and to collect such samples.
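For one-dimensional x, grouping similar samples can be sketched with a custom batch sampler. The class SimilarXBatchSampler below is a hypothetical helper, not code from the experiments in this article: it sorts the samples by x, cuts the sorted order into contiguous batches, and shuffles only the order in which the batches are visited.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class SimilarXBatchSampler:
    """Batch sampler that groups samples with neighboring 1-D x values."""

    def __init__(self, x, batch_size):
        order = torch.argsort(x.flatten())  # sample indices sorted by x
        self.batches = [chunk.tolist() for chunk in torch.split(order, batch_size)]

    def __iter__(self):
        # Shuffle the order of the batches, not their contents
        for i in torch.randperm(len(self.batches)).tolist():
            yield self.batches[i]

    def __len__(self):
        return len(self.batches)

torch.manual_seed(0)
x = torch.rand(1000, 1) * 10 - 5
y = 2 * torch.sin(x) + torch.randn(1000, 1)

sampler = SimilarXBatchSampler(x, batch_size=100)
dataloader = DataLoader(TensorDataset(x, y), batch_sampler=sampler)

# Each batch now covers only a narrow slice of the x axis
batch_ranges = [(xb.max() - xb.min()).item() for xb, _ in dataloader]
print([round(r, 2) for r in batch_ranges])
```

Whether this grouping actually removes flattening for a given dataset would still need to be checked with the detection method described above.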

Mitigation method for “Flattening”

On the other hand, a flattening mitigation method can be considered an extension of the flattening detection method. The following symbols are used in the flowchart below.

  • p0: 50.00 percentile value
  • p1: 84.13 percentile value
  • p2: 97.72 percentile value
  • p3: 99.87 percentile value

Using the above variables, an appropriate 99.87 percentile value can be obtained using the following flowchart.

Image by author

Simple Implementation for multiple quantile estimates

In the previous example code for the custom loss function, a single percentile value was given, and the quantile loss was returned, but quantile value estimations corresponding to multiple percentile values are needed to implement detection and mitigation methods. To handle this, a single neural network is used to estimate multiple percentile values simultaneously, and the custom loss function is also changed to a function that returns the sum of quantile losses corresponding to the given multiple percentile values.

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

def multi_quantile_loss(preds, target, quantiles):
    assert isinstance(preds, torch.Tensor), "Predictions must be a torch.Tensor"
    assert isinstance(target, torch.Tensor), "Target must be a torch.Tensor"
    assert isinstance(quantiles, (list, torch.Tensor)), "Quantiles must be a list or torch.Tensor"
    assert len(preds.shape) == 2, "Predictions must have 2 dimensions (batch_size, num_quantiles)"
    assert preds.shape[1] == len(quantiles), f"Number of predictions ({preds.shape[1]}) must match the number of quantiles ({len(quantiles)})"
    assert preds.shape == target.shape, "Shape of predictions must match shape of target"

    if isinstance(quantiles, list):
        assert all(0 < q < 1 for q in quantiles), "Quantiles should be in (0, 1) range"
    else:
        assert torch.all((0 < quantiles) & (quantiles < 1)), "Quantiles should be in (0, 1) range"

    # Convert quantiles to a tensor if it's a list
    if isinstance(quantiles, list):
        quantiles_tensor = torch.tensor(quantiles, device=preds.device).view(1, -1)
    else:
        quantiles_tensor = quantiles.view(1, -1)

    # Calculate errors
    errors = target - preds

    # Calculate losses for each quantile
    losses = torch.max((quantiles_tensor - 1) * errors, quantiles_tensor * errors)

    # Sum the losses and take the mean
    loss = torch.mean(torch.sum(losses, dim=1))

    return loss

The structure of the neural network is also changed so that the number of nodes in the output layer can be specified.

# Define a simple neural network architecture
class QuantileNet(nn.Module):
    def __init__(self, output_size):
        super(QuantileNet, self).__init__()
        self.fc1 = nn.Linear(1, 64) # Assuming input features are 1-dimensional
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, output_size) # Output layer with output_size nodes

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

The code for generating dummy data is almost the same.

# const
NUM_SAMPLES = 10000
SHAPE = (NUM_SAMPLES, 1)
QUANTILES = [0.5000, 0.8413, 0.9772, 0.9987]

# Generate random tensor with values in the range -5 to 5
x_tensor = torch.rand(SHAPE) * 10 - 5
x_train = x_tensor / 5 # Fixed scaling to the range -1 to 1

# Generating y values using torch operations instead of numpy
y = torch.sin(x_tensor) * 2 + torch.randn(NUM_SAMPLES, 1) * 1
y_train = y.type(torch.float32)
# Repeat the target once per quantile so that it matches the shape of the predictions
y_train_expanded = y_train.expand(-1, len(QUANTILES))

The training part is almost the same. The batch size is 256.

# Instantiate the model for batch training
model = QuantileNet(output_size=len(QUANTILES))

# Define optimizer for batch training
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training parameters
epochs = 101
batch_size = 256

# Convert the training data into a PyTorch Dataset
dataset = TensorDataset(x_train, y_train_expanded)

# Create a DataLoader to handle batching and shuffling
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop with DataLoader
model.train()
for epoch in range(epochs):
    for x_batch, y_batch in dataloader:
        optimizer.zero_grad()
        preds = model(x_batch)
        loss = multi_quantile_loss(preds, y_batch, QUANTILES)
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

Plot the results in a graph.

# Set the model to evaluation mode
model.eval()

# Predict the quantiles
with torch.no_grad():
    predictions = model(x_train)

# Convert the predictions and the unscaled x values to numpy for plotting
x_train_np = x_tensor.numpy().flatten()
y_train_np = y_train.numpy().flatten()
predictions_np = predictions.numpy()

# Plotting
plt.figure(figsize=(12, 6))
plt.scatter(x_train_np, y_train_np, label='Actual Data', color='blue', marker='.', alpha=0.1)
plt.scatter(x_train_np, predictions_np[:, 0], label=f'{QUANTILES[0]:.4f} Percentile', color='green', marker='.', alpha=0.1)
plt.scatter(x_train_np, predictions_np[:, 1], label=f'{QUANTILES[1]:.4f} Percentile', color='red', marker='.', alpha=0.1)
plt.scatter(x_train_np, predictions_np[:, 2], label=f'{QUANTILES[2]:.4f} Percentile', color='purple', marker='.', alpha=0.1)
plt.scatter(x_train_np, predictions_np[:, 3], label=f'{QUANTILES[3]:.4f} Percentile', color='orange', marker='.', alpha=0.1)
plt.title('Predicted Quantiles vs. Actual Data')
plt.xlabel('x_train')
plt.ylabel('y_train and Predicted Quantiles')
plt.legend()
plt.show()

Here is the graph.

Image by author

Whoa, this is unexpectedly good! What a nice surprise! The flattening of the 2σ and 3σ estimates has been suppressed. Why did this happen? Each percentile estimate was obtained simultaneously to detect and mitigate the flattening in a later step. However, this resulted in stopping the flattening of the 2σ and 3σ estimates. This is a pleasant byproduct of simultaneous estimation in a single network. Just to be sure, we also check the ratio of each forecast.

# Plot Differences 
plt.figure(figsize=(12, 6))
plt.scatter(x_train_np, predictions_np[:, 3] - predictions_np[:, 0], label=f'{QUANTILES[3]:.4f} Percentile - {QUANTILES[0]:.4f} Percentile', color='green', marker='.', alpha=0.5)
plt.scatter(x_train_np, predictions_np[:, 2] - predictions_np[:, 0], label=f'{QUANTILES[2]:.4f} Percentile - {QUANTILES[0]:.4f} Percentile', color='red', marker='.', alpha=0.5)
plt.scatter(x_train_np, predictions_np[:, 1] - predictions_np[:, 0], label=f'{QUANTILES[1]:.4f} Percentile - {QUANTILES[0]:.4f} Percentile', color='purple', marker='.', alpha=0.5)
plt.title('Differences in percentile estimates')
plt.xlabel('x_train')
plt.ylabel('differences')
plt.legend()
plt.show()

Let’s see.

Image by author

The differences between each percentile estimate and the median estimate are close to 1, 2, and 3. It seems unnecessary to correct them by following the flowchart shown in the “Mitigation method for “Flattening”” section. It should be noted, however, that calculating multiple quantile estimates simultaneously in a single network does not ensure that flattening will not occur. It is recommended to check whether or not flattening has occurred. If it has, assess whether the irreducible errors follow a normal distribution, and if they do, consider the mitigation method.

Summary

This article described quantile loss as a custom loss used in training neural networks, the tendency for predictions to flatten, and its detection and mitigation. It also described the custom loss function when multiple quantile estimates are estimated simultaneously in a single network and the pleasant byproducts of simultaneous estimation. I hope this article will help implement demand forecasting with neural networks.

In addition, I wrote an article on how to estimate uncertainty in neural networks with Monte Carlo dropout. I am sure you will enjoy this one as well.


Written by Shiro Matsumoto

Here's something that hasn't been written yet and isn't a copy and paste. Data Scientist in Washington, DC
