JAX: More Than Just an Intro
Staying competitive often means leveraging the right technologies to solve complex problems. One such technology gaining momentum in the machine learning community is JAX. But JAX isn’t just for academic research or building the latest neural networks: it has practical applications that can drive significant improvements in business operations.
In this post, we’ll explore how JAX can be applied to a common business challenge: inventory optimization. We’ll start with the basics and gradually build up to a comprehensive solution, exploring JAX’s capabilities along the way.
What is JAX?
JAX is a powerful library for high-performance numerical computing and machine learning. Developed by Google, JAX combines the ease of use of NumPy with the power of automatic differentiation and just-in-time (JIT) compilation. Here are some of the features that come with JAX:
- NumPy-like API: If you’re familiar with NumPy, you’ll feel right at home with JAX. It provides a similar interface for array operations through jax.numpy (often imported as jnp), making it easy to get started.
- Automatic Differentiation: JAX can automatically compute derivatives of your functions using jax.grad(). This is crucial for gradient-based optimization, which is at the heart of many machine learning algorithms and optimization techniques.
- JIT Compilation: With jax.jit(), JAX can compile your Python functions to optimized machine code, significantly speeding up execution.
- Vectorization: JAX’s jax.vmap() allows for efficient vectorization of operations, enabling you to easily apply functions over batches of data.
- Functional Programming Model: JAX encourages a functional programming style, avoiding in-place mutations of arrays. This leads to more predictable and easier-to-optimize code.
- Hardware Acceleration: JAX can leverage hardware accelerators like GPUs and TPUs, enabling blazing-fast computations on large datasets.
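To make the first four of these concrete before we apply them, here’s a minimal sketch (squared_error is an illustrative toy function, not part of any API) that exercises the NumPy-like API, grad, jit, and vmap together:

import jax
import jax.numpy as jnp

# A plain Python function on arrays, written NumPy-style
def squared_error(w, x, y):
    return jnp.mean((w * x - y) ** 2)

x = jnp.arange(5.0)
y = 3.0 * x

# Automatic differentiation: d(loss)/dw evaluated at w = 1.0
print(jax.grad(squared_error)(1.0, x, y))

# JIT compilation: the same function, compiled on first call
fast_loss = jax.jit(squared_error)
print(fast_loss(1.0, x, y))

# Vectorization: evaluate the loss for a whole batch of candidate weights
weights = jnp.linspace(0.0, 5.0, 10)
print(jax.vmap(lambda w: squared_error(w, x, y))(weights))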
Now that we have a basic understanding of JAX, let’s dive into our business problem: inventory optimization.
The Business Problem: Inventory Optimization
Imagine you’re running a retail company that’s struggling with inventory management. You need to balance two competing goals:
- Having enough stock to meet customer demand and avoid stockouts.
- Minimizing excess inventory to reduce holding costs and the risk of obsolescence.
This is a classic optimization problem, and it’s one where JAX’s capabilities can really shine. By leveraging JAX, we can build a demand forecasting model and optimize our inventory levels across multiple stores and products.
In the following sections, we’ll walk through the process of building this solution, step by step. We’ll start with a simple demand forecasting model and gradually increase its complexity, showcasing different JAX features along the way.
Let’s get started by setting up our environment and diving into some code!
Setting Up the Environment
Before we dive into our inventory optimization problem, let’s set up our environment and import the necessary libraries. We’ll be using JAX, along with some helper libraries for data manipulation and visualization.
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import matplotlib.pyplot as plt
# Enable 64-bit precision for more accurate calculations
jax.config.update("jax_enable_x64", True)
# Check if GPU is available
print("GPU available:", jax.default_backend() == "gpu")
With our environment set up, let’s create a simple demand forecasting model. We’ll start with a linear regression model that predicts demand based on a single feature (e.g., time).
Implementing a Simple Demand Forecasting Model
Let’s define our model parameters and functions:
# Define the model parameters as a dictionary (a JAX pytree)
def init_params(slope=0.0, intercept=0.0):
    return {"slope": slope, "intercept": intercept}

# Define the model: a linear function of x
def predict(params, x):
    return params["slope"] * x + params["intercept"]

# Define the loss function (Mean Squared Error)
def mse_loss(params, x, y):
    y_pred = predict(params, x)
    return jnp.mean((y_pred - y) ** 2)

# Define the gradient function (differentiate, then JIT-compile)
grad_mse_loss = jit(grad(mse_loss))

# Define the update function: one step of gradient descent
@jit
def update(params, x, y, learning_rate=0.01):
    grads = grad_mse_loss(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
Now, let’s generate some sample data and train our model:
# Generate sample data
key = jax.random.PRNGKey(0)
x = jnp.linspace(0, 10, 100)
true_slope, true_intercept = 2.5, 5.0
y = true_slope * x + true_intercept + jax.random.normal(key, x.shape) * 2
# Initialize parameters
params = init_params()
# Train the model
num_epochs = 200
for epoch in range(num_epochs):
    params = update(params, x, y)
# Print final parameters
print(f"Trained slope: {params['slope']:.4f}, Trained intercept: {params['intercept']:.4f}")
print(f"True slope: {true_slope:.4f}, True intercept: {true_intercept:.4f}")
# Visualize the results
plt.scatter(x, y, label='Data')
plt.plot(x, predict(params, x), color='red', label='Prediction')
plt.legend()
plt.xlabel('Time')
plt.ylabel('Demand')
plt.title('Simple Demand Forecasting Model')
plt.show()
Trained slope: 2.7733, Trained intercept: 3.0959
True slope: 2.5000, True intercept: 5.0000
In this example, we’ve created a simple linear regression model to forecast demand based on time. We’ve used JAX’s automatic differentiation (grad) to compute the gradients of our loss function, and we’ve used jit to compile our update function for faster execution.
This simple model demonstrates several key features of JAX:
- NumPy-like API: We used jax.numpy (imported as jnp) for array operations, which should feel familiar to NumPy users.
- Automatic Differentiation: We used jax.grad to automatically compute the gradient of our loss function.
- JIT Compilation: We used @jit to compile our update function, which can significantly speed up execution, especially for more complex models (see the timing sketch below).
- Functional Programming: Notice how we return a new parameter dictionary in our update function, rather than modifying the existing one. This aligns with JAX’s functional programming model.
- Pytrees: We used a simple dictionary as a pytree to represent our model parameters, which JAX can easily work with.
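If you want to see the JIT speedup for yourself, a rough micro-benchmark along these lines works (exact timings depend on your hardware, and for a function this small the gap may be modest):

import time

x_big = jnp.linspace(0, 10, 1_000_000)
y_big = 2.5 * x_big + 5.0

jitted_loss = jit(mse_loss)
jitted_loss(params, x_big, y_big).block_until_ready()  # first call compiles

start = time.perf_counter()
for _ in range(100):
    jitted_loss(params, x_big, y_big).block_until_ready()
print(f"jitted:    {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
for _ in range(100):
    mse_loss(params, x_big, y_big).block_until_ready()
print(f"un-jitted: {time.perf_counter() - start:.4f}s")

This reuses mse_loss and params from above. The block_until_ready() calls matter because JAX dispatches work asynchronously; without them we would only be timing the dispatch.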
While this simple model is a good starting point, it has some limitations for real-world inventory optimization:
- It only considers a single feature (time) and assumes a linear relationship with demand.
- It doesn’t account for seasonal patterns or other complex trends in demand.
- It doesn’t consider multiple products or stores, which is crucial for real-world inventory management.
In the next section, we’ll expand on this model to address these limitations and create a more robust inventory optimization solution.
Expanding the Model for Multiple Products and Stores
Now that we have a basic understanding of how to use JAX for demand forecasting, let’s expand our model to handle multiple products across multiple stores. This will bring us closer to a real-world inventory optimization scenario.
We’ll make the following enhancements:
- Support multiple products and stores
- Introduce seasonality in our demand model
- Use more features for prediction (time, store-specific features, product-specific features)
- Implement a more sophisticated optimization strategy
Let’s build this model step by step.
Step 1: Set up the environment
First, let’s import the necessary libraries and set up our environment:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import matplotlib.pyplot as plt
# Enable 64-bit precision for more accurate calculations
jax.config.update("jax_enable_x64", True)
This setup is similar to our previous example, but we’ll be making more extensive use of JAX’s features as we build our expanded model.
Step 2: Define model parameters
Now, let’s define a function to initialize our model parameters:
def init_params(n_products, n_stores, n_features):
    # n_stores is accepted for interface symmetry; the parameters themselves
    # are shared across stores, so it isn't used here.
    return {
        "weights": jax.random.normal(jax.random.PRNGKey(0), (n_products, n_features)),
        "bias": jax.random.normal(jax.random.PRNGKey(1), (n_products,)),
        "seasonal_factors": jax.random.normal(jax.random.PRNGKey(2), (12, n_products))  # monthly seasonality
    }
This function initializes our model parameters:
- weights: A matrix of weights for each product and feature
- bias: A bias term for each product
- seasonal_factors: Monthly seasonal factors for each product
Notice how we’re using JAX’s random number generation to initialize these parameters. The use of different PRNGKeys ensures that we get different random values for each parameter.
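As an aside, the more idiomatic JAX pattern is to create one key and split it into subkeys, rather than constructing several PRNGKeys by hand. A sketch of the same initialization in that style:

key = jax.random.PRNGKey(42)
# Split one key into three independent subkeys, one per parameter group
w_key, b_key, s_key = jax.random.split(key, 3)

params = {
    "weights": jax.random.normal(w_key, (5, 4)),
    "bias": jax.random.normal(b_key, (5,)),
    "seasonal_factors": jax.random.normal(s_key, (12, 5)),
}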
Step 3: Define the prediction function
Next, let’s define our prediction function:
def predict(params, features, month):
    # features: (..., n_stores, n_features) -> base_demand: (..., n_stores, n_products)
    base_demand = jnp.dot(features, params["weights"].T) + params["bias"]
    # Seasonal factor(s) for the given month(s): (..., n_products)
    seasonal_factor = params["seasonal_factors"][month]
    # Insert a store axis so the per-product factors broadcast across stores
    return base_demand * seasonal_factor[..., None, :]
This function predicts demand based on:
- The input features
- The learned weights and biases
- The seasonal factor for the given month
The use of jnp.dot allows us to efficiently compute the base demand for all products and stores at once. We then apply the seasonal factor to get our final prediction. Note the [..., None, :] indexing: it inserts a store axis so the per-product seasonal factors broadcast correctly across stores, whether month is a single index or a whole batch of months.
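Since the array shapes are the trickiest part of this model, here’s a quick sanity check of that broadcasting, using dummy zero arrays with the dimensions we’ll use later:

n_samples, n_stores, n_features, n_products = 1000, 3, 4, 5

features_demo = jnp.zeros((n_samples, n_stores, n_features))
weights_demo = jnp.zeros((n_products, n_features))

base_demand = jnp.dot(features_demo, weights_demo.T)
print(base_demand.shape)  # (1000, 3, 5): samples x stores x products

seasonal_demo = jnp.zeros((12, n_products))
months_demo = jnp.zeros(n_samples, dtype=jnp.int32)
factor = seasonal_demo[months_demo]  # (1000, 5): one row of factors per sample
print((base_demand * factor[..., None, :]).shape)  # (1000, 3, 5)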
Step 4: Define the loss function and its gradient
Now, let’s define our loss function and its gradient:
def mse_loss(params, features, month, actual_demand):
    predicted_demand = predict(params, features, month)
    return jnp.mean((predicted_demand - actual_demand) ** 2)

grad_mse_loss = jit(grad(mse_loss))
We’re using Mean Squared Error (MSE) as our loss function. The grad_mse_loss function is created by applying jax.grad to our loss function and then JIT-compiling it for efficiency.
Step 5: Define the update function
Let’s define the function to update our model parameters:
@jit
def update(params, features, month, actual_demand, learning_rate=0.01):
    grads = grad_mse_loss(params, features, month, actual_demand)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
This function computes the gradients of our loss with respect to the parameters and then updates the parameters using gradient descent. The use of jax.tree_util.tree_map allows us to easily apply this update to our entire parameter tree.
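If tree_map is new to you, here’s a tiny standalone illustration: it walks two pytrees with the same structure in lockstep and applies the function to each pair of leaves:

params_demo = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(2)}
grads_demo = {"weights": jnp.full((2, 3), 0.5), "bias": jnp.ones(2)}

# Apply the same elementwise update to every leaf of the parameter tree
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params_demo, grads_demo)
print(jax.tree_util.tree_map(jnp.shape, new_params))  # {'bias': (2,), 'weights': (2, 3)}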
Step 6: Generate sample data
To test our model, we need some sample data. Let’s create a function to generate this:
def generate_sample_data(n_samples, n_products, n_stores, n_features):
    key = jax.random.PRNGKey(0)
    features = jax.random.normal(key, (n_samples, n_stores, n_features))
    true_weights = jax.random.normal(jax.random.PRNGKey(1), (n_products, n_features))
    true_bias = jax.random.normal(jax.random.PRNGKey(2), (n_products,))
    true_seasonal_factors = jnp.abs(jax.random.normal(jax.random.PRNGKey(3), (12, n_products))) + 0.5
    months = jax.random.randint(jax.random.PRNGKey(4), (n_samples,), 0, 12)
    base_demand = jnp.dot(features, true_weights.T) + true_bias
    seasonal_factors = true_seasonal_factors[months]
    actual_demand = base_demand * seasonal_factors[:, None, :] + jax.random.normal(jax.random.PRNGKey(5), (n_samples, n_stores, n_products)) * 0.1
    return features, months, actual_demand
This function generates:
- Random features for each sample, store, and feature
- True weights, biases, and seasonal factors (which our model will try to learn)
- Random months for each sample
- Actual demand based on these parameters, with some added noise
Step 7: Create the training loop
Now, let’s define our training loop:
def train_model(params, features, months, actual_demand, num_epochs=1000):
    for epoch in range(num_epochs):
        params = update(params, features, months, actual_demand)
        if epoch % 100 == 0:
            loss = mse_loss(params, features, months, actual_demand)
            print(f"Epoch {epoch}, Loss: {loss:.4f}")
    return params
This function iteratively updates our model parameters using the update function we defined earlier. It also prints out the loss every 100 epochs so we can monitor the training progress.
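A side note on style: because update is already jitted, this plain Python loop is fine, but for long loops JAX users often fold the whole loop into jax.lax.scan so it compiles as a single program. A sketch, reusing update and mse_loss from above:

from jax import lax

def train_model_scan(params, features, months, actual_demand, num_epochs=1000):
    def step(carry_params, _):
        new_params = update(carry_params, features, months, actual_demand)
        return new_params, mse_loss(new_params, features, months, actual_demand)

    # scan threads the parameters through num_epochs steps and stacks the per-step losses
    final_params, losses = lax.scan(step, params, None, length=num_epochs)
    return final_params, losses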
Step 8: Run the model
Finally, let’s put it all together and run our model:
# Main execution
n_samples, n_products, n_stores, n_features = 1000, 5, 3, 4
features, months, actual_demand = generate_sample_data(n_samples, n_products, n_stores, n_features)
params = init_params(n_products, n_stores, n_features)
trained_params = train_model(params, features, months, actual_demand)
# Visualize results for a single product and store
product_idx, store_idx = 0, 0
# Inside vmap, predict sees one sample at a time, so its output is (n_stores, n_products)
predicted_demand = vmap(lambda f, m: predict(trained_params, f, m)[store_idx, product_idx])(features, months)
plt.figure(figsize=(12, 6))
plt.scatter(range(n_samples), actual_demand[:, store_idx, product_idx], alpha=0.5, label='Actual Demand')
plt.scatter(range(n_samples), predicted_demand, alpha=0.5, label='Predicted Demand')
plt.xlabel('Sample')
plt.ylabel('Demand')
plt.title(f'Actual vs Predicted Demand for Product {product_idx} in Store {store_idx}')
plt.legend()
plt.show()
This code:
- Generates sample data
- Initializes model parameters
- Trains the model
- Uses vmap to efficiently predict demand for all samples
- Visualizes the results for a single product and store
The use of vmap here is particularly noteworthy. It allows us to vectorize our prediction function over all samples, greatly improving efficiency.
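One detail worth spelling out: by default vmap maps over axis 0 of every argument, and the in_axes argument makes this explicit. A minimal illustration with a stand-in for our prediction function:

def per_sample(f, m):
    # f: (n_stores, n_features) for ONE sample; m: a scalar month index
    return f.sum() * m  # stand-in for predict(trained_params, f, m)

# in_axes=(0, 0): map over axis 0 (the sample axis) of both arguments
batched = vmap(per_sample, in_axes=(0, 0))
print(batched(features, months).shape)  # (1000,): one value per sample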
This expanded model introduces several new concepts and JAX features:
- Multi-dimensional data: We’re now handling data for multiple products across multiple stores, which requires more complex array operations.
- Seasonality: We’ve introduced seasonal factors to capture monthly variations in demand.
- Vectorization: We use vmap to efficiently apply our prediction function across multiple samples.
- More complex pytrees: Our params dictionary now holds multiple arrays of different shapes, demonstrating JAX’s ability to work with arbitrary pytrees.
- Batch processing: Instead of training on individual data points, we’re now processing batches of data, which is more efficient and realistic for large datasets.
The model now predicts demand based on store-specific features, product-specific weights, and seasonal factors. This allows it to capture more complex patterns in the data, making it more suitable for real-world inventory optimization scenarios.
In a real-world application, you would typically split your data into training and validation sets, implement early stopping to prevent overfitting, and perhaps use more sophisticated optimization techniques like Adam instead of simple gradient descent.
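For instance, here’s a sketch of swapping our hand-rolled gradient descent for Adam via the optax library (this assumes optax is installed; the rest of the training loop is unchanged):

import optax

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

@jit
def adam_update(params, opt_state, features, month, actual_demand):
    grads = grad_mse_loss(params, features, month, actual_demand)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

for epoch in range(1000):
    params, opt_state = adam_update(params, opt_state, features, months, actual_demand)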
Next, we could further enhance this model by:
- Implementing a more sophisticated inventory optimization strategy based on the demand predictions.
- Adding constraints for warehouse capacity, budget limitations, or product shelf life.
- Incorporating external factors like promotions, competitors’ actions, or economic indicators.
These enhancements would bring us closer to a comprehensive inventory management system that could significantly improve business operations.
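To give a flavor of the first enhancement, here’s a rough sketch of how the same JAX machinery could pick order quantities against predicted demand. The cost structure and every number below are illustrative assumptions, not outputs of the model above:

# Illustrative cost parameters (assumptions for this sketch)
HOLDING_COST = 0.2   # cost per unit of unsold stock
STOCKOUT_COST = 1.0  # cost per unit of unmet demand
CAPACITY = 500.0     # assumed per-store capacity across all products

def inventory_cost(stock, predicted_demand):
    # stock, predicted_demand: (n_stores, n_products)
    holding = HOLDING_COST * jnp.maximum(stock - predicted_demand, 0.0)
    stockout = STOCKOUT_COST * jnp.maximum(predicted_demand - stock, 0.0)
    # Soft penalty for exceeding each store's capacity
    over_capacity = jnp.maximum(stock.sum(axis=-1) - CAPACITY, 0.0)
    return jnp.sum(holding + stockout) + 10.0 * jnp.sum(over_capacity ** 2)

grad_cost = jit(grad(inventory_cost))

# Gradient descent on stock levels for one sample's predicted demand
predicted = predict(trained_params, features[0], months[0])  # (n_stores, n_products)
stock = jnp.full_like(predicted, 50.0)
for _ in range(500):
    stock = jnp.maximum(stock - 0.05 * grad_cost(stock, predicted), 0.0)  # stock can't go negative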