JAX: More Than Just an Intro


Staying competitive often means applying the right technology to hard problems. One technology gaining traction in the machine learning community is JAX. JAX isn’t only for academic research or building the latest neural networks; it also has practical applications that can drive significant improvements in business operations.

In this post, we’ll explore how JAX can be applied to a common business challenge: inventory optimization. We’ll start with the basics and gradually build up to a comprehensive solution, exploring JAX’s capabilities along the way.

What is JAX?

JAX is a powerful library for high-performance numerical computing and machine learning. Developed by Google, JAX combines the ease of use of NumPy with the power of automatic differentiation and just-in-time (JIT) compilation. Here are some of the features that come with JAX:

  1. NumPy-like API: If you’re familiar with NumPy, you’ll feel right at home with JAX. It provides a similar interface for array operations through jax.numpy (often imported as jnp), making it easy to get started.

  2. Automatic Differentiation: JAX can automatically compute derivatives of your functions using jax.grad(). This is crucial for gradient-based optimization, which is at the heart of many machine learning algorithms and optimization techniques.

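  3. JIT Compilation: With jax.jit(), JAX traces your Python functions and compiles them through the XLA compiler into optimized code for CPU, GPU, or TPU. This often speeds up execution substantially, especially for functions that are called many times.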

  4. Vectorization: JAX’s jax.vmap() allows for efficient vectorization of operations, enabling you to easily apply functions over batches of data.

  5. Functional Programming Model: JAX encourages a functional programming style, avoiding in-place mutations of arrays. This leads to more predictable and easier-to-optimize code.

  6. Hardware Acceleration: JAX can leverage hardware accelerators like GPUs and TPUs, enabling blazing-fast computations on large datasets.
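
To make the features above concrete, here is a minimal sketch that exercises grad, jit, vmap, and the functional update style on a toy function. The function f and the variable names are illustrative only and are not part of the inventory model built later.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: f(x) = x^2 + 3x, so f'(x) = 2x + 3
    return x ** 2 + 3.0 * x

df = jax.grad(f)            # automatic differentiation
fast_df = jax.jit(df)       # JIT-compiled version of the derivative
batched_df = jax.vmap(df)   # derivative applied element-wise over a batch

xs = jnp.array([0.0, 1.0, 2.0])
single = df(2.0)            # derivative at x = 2
batch = batched_df(xs)      # derivatives at x = 0, 1, 2

# Functional updates: .at[...].set(...) returns a new array
# and leaves the original untouched.
a = jnp.zeros(3)
b = a.at[0].set(1.0)

print(single, batch, a, b)
```

The transformations compose freely, which is why jit(grad(...)) appears throughout the rest of this post.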

Now that we have a basic understanding of JAX, let’s dive into our business problem: inventory optimization.

The Business Problem: Inventory Optimization

Imagine you’re running a retail company that’s struggling with inventory management. You need to balance two competing goals:

  1. Having enough stock to meet customer demand and avoid stockouts.
  2. Minimizing excess inventory to reduce holding costs and the risk of obsolescence.

This is a classic optimization problem, and it’s one where JAX’s capabilities can really shine. By leveraging JAX, we can build a demand forecasting model and optimize our inventory levels across multiple stores and products.

In the following sections, we’ll walk through the process of building this solution, step by step. We’ll start with a simple demand forecasting model and gradually increase its complexity, showcasing different JAX features along the way.

Let’s get started by setting up our environment and diving into some code!

Setting Up the Environment

Before we dive into our inventory optimization problem, let’s set up our environment and import the necessary libraries. We’ll be using JAX, along with some helper libraries for data manipulation and visualization.

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import matplotlib.pyplot as plt
 
# Enable 64-bit precision for more accurate calculations
jax.config.update("jax_enable_x64", True)
 
# Check if GPU is available
print("GPU available:", jax.default_backend() == "gpu")

With our environment set up, let’s create a simple demand forecasting model. We’ll start with a linear regression model that predicts demand based on a single feature (e.g., time).

Implementing a Simple Demand Forecasting Model

Let’s define our model parameters and functions:

# Define the model parameters as a dictionary (pytree)
def init_params(slope=0.0, intercept=0.0):
    return {"slope": slope, "intercept": intercept}

# Define the model
def predict(params, x):
    return params["slope"] * x + params["intercept"]

# Define the loss function (Mean Squared Error)
def mse_loss(params, x, y):
    y_pred = predict(params, x)
    return jnp.mean((y_pred - y) ** 2)

# Define the gradient function
grad_mse_loss = jit(grad(mse_loss))

# Define the update function
@jit
def update(params, x, y, learning_rate=0.01):
    grads = grad_mse_loss(params, x, y)
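    # jax.tree_map is deprecated in recent JAX releases; jax.tree_util.tree_map is the stable spelling
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)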

Now, let’s generate some sample data and train our model:

# Generate sample data
key = jax.random.PRNGKey(0)
x = jnp.linspace(0, 10, 100)
true_slope, true_intercept = 2.5, 5.0
y = true_slope * x + true_intercept + jax.random.normal(key, x.shape) * 2

# Initialize parameters
params = init_params()

# Train the model
num_epochs = 200
for epoch in range(num_epochs):
    params = update(params, x, y)

# Print final parameters
print(f"Trained slope: {params['slope']:.4f}, Trained intercept: {params['intercept']:.4f}")
print(f"True slope: {true_slope:.4f}, True intercept: {true_intercept:.4f}")

# Visualize the results
plt.scatter(x, y, label='Data')
plt.plot(x, predict(params, x), color='red', label='Prediction')
plt.legend()
plt.xlabel('Time')
plt.ylabel('Demand')
plt.title('Simple Demand Forecasting Model')
plt.show()

Output:

Trained slope: 2.7733, Trained intercept: 3.0959
True slope: 2.5000, True intercept: 5.0000

[Figure: Simple Demand Forecasting Model]

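In this example, we’ve created a simple linear regression model to forecast demand based on time. We’ve used JAX’s automatic differentiation (grad) to compute the gradients of our loss function, and we’ve used jit to compile our update function for faster execution. Note that after 200 steps the slope is close to its true value, but the intercept (3.10 versus 5.0) has not fully converged. Plain gradient descent is slow along that direction on unscaled inputs, so training for more epochs or standardizing x would narrow the gap.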

This simple model demonstrates several key features of JAX:

  1. NumPy-like API: We used jax.numpy (imported as jnp) for array operations, which should feel familiar to NumPy users.
  2. Automatic Differentiation: We used jax.grad to automatically compute the gradient of our loss function.
  3. JIT Compilation: We used @jit to compile our update function, which can significantly speed up execution, especially for more complex models.
  4. Functional Programming: Notice how we return a new parameter dictionary in our update function, rather than modifying the existing one. This aligns with JAX’s functional programming model.
  5. Pytrees: We used a simple dictionary as a pytree to represent our model parameters, which JAX can easily work with.
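
The pytree idea in the last point can be seen in isolation. The sketch below uses made-up parameter and gradient values (not the trained ones above) to show how jax.tree_util.tree_map applies one gradient descent step to every leaf of a dictionary.

```python
import jax
import jax.numpy as jnp

# Hypothetical values, chosen only to make the arithmetic easy to follow
params = {"slope": jnp.array(1.0), "intercept": jnp.array(2.0)}
grads = {"slope": jnp.array(0.5), "intercept": jnp.array(-1.0)}

learning_rate = 0.1

# tree_map walks both dictionaries in parallel and returns a NEW dictionary
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)

# Leaves of a dict pytree are returned in sorted-key order
leaves = jax.tree_util.tree_leaves(params)

print(new_params)  # slope: 1.0 - 0.1 * 0.5, intercept: 2.0 - 0.1 * (-1.0)
print(leaves)
```

Because params itself is never modified, the same pattern works unchanged inside jit-compiled functions.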

While this simple model is a good starting point, it has some limitations for real-world inventory optimization:

  1. It only considers a single feature (time) and assumes a linear relationship with demand.
  2. It doesn’t account for seasonal patterns or other complex trends in demand.
  3. It doesn’t consider multiple products or stores, which is crucial for real-world inventory management.

In the next section, we’ll expand on this model to address these limitations and create a more robust inventory optimization solution.

Expanding the Model for Multiple Products and Stores

Now that we have a basic understanding of how to use JAX for demand forecasting, let’s expand our model to handle multiple products across multiple stores. This will bring us closer to a real-world inventory optimization scenario.

We’ll make the following enhancements:

  1. Support multiple products and stores
  2. Introduce seasonality in our demand model
  3. Use more features for prediction (time, store-specific features, product-specific features)
  4. Train all products and stores together with a single gradient descent update over a larger parameter pytree

Let’s build this model step by step.

Step 1: Set up the environment

First, let’s import the necessary libraries and set up our environment:

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import matplotlib.pyplot as plt
 
# Enable 64-bit precision for more accurate calculations
jax.config.update("jax_enable_x64", True)

This setup is similar to our previous example, but we’ll be making more extensive use of JAX’s features as we build our expanded model.

Step 2: Define model parameters

Now, let’s define a function to initialize our model parameters:

def init_params(n_products, n_stores, n_features):
    return {
        "weights": jax.random.normal(jax.random.PRNGKey(0), (n_products, n_features)),
        "bias": jax.random.normal(jax.random.PRNGKey(1), (n_products,)),
        "seasonal_factors": jax.random.normal(jax.random.PRNGKey(2), (12, n_products))  # Monthly seasonality
    }

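This function initializes three groups of model parameters:

  1. weights: an (n_products, n_features) matrix holding one weight vector per product.
  2. bias: one bias term per product, with shape (n_products,).
  3. seasonal_factors: a (12, n_products) matrix holding one multiplicative factor per month and product.

Notice how we’re using JAX’s random number generation to initialize these parameters. JAX’s random functions are deterministic given a key, so passing a different PRNGKey to each call ensures that each parameter gets different random values. Note that n_stores is accepted but not used: the weights are shared across stores, and store-level variation enters only through the features.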
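
Hard-coding a separate seed for every parameter works, but the more common JAX idiom is to start from one key and derive independent subkeys with jax.random.split. The sketch below is an alternative initializer under that assumption; the name init_params_split and the seed 42 are illustrative and not part of the original code.

```python
import jax

def init_params_split(key, n_products, n_features):
    # Derive three independent subkeys from a single parent key
    w_key, b_key, s_key = jax.random.split(key, 3)
    return {
        "weights": jax.random.normal(w_key, (n_products, n_features)),
        "bias": jax.random.normal(b_key, (n_products,)),
        # One row per month, one column per product
        "seasonal_factors": jax.random.normal(s_key, (12, n_products)),
    }

params_a = init_params_split(jax.random.PRNGKey(42), n_products=5, n_features=4)
params_b = init_params_split(jax.random.PRNGKey(42), n_products=5, n_features=4)

# Same parent key, same parameters: initialization is fully reproducible
print(params_a["weights"].shape, params_a["bias"].shape, params_a["seasonal_factors"].shape)
```

Passing the key in as an argument also makes it easy to run the same experiment with several different seeds.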

Step 3: Define the prediction function

Next, let’s define our prediction function:

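def predict(params, features, month):
    # features: (..., n_stores, n_features) -> base_demand: (..., n_stores, n_products)
    base_demand = jnp.dot(features, params["weights"].T) + params["bias"]
    # seasonal_factors[month] has shape (..., n_products); insert a store axis
    # so it broadcasts against base_demand for a single month or a batch of months
    seasonal_factor = params["seasonal_factors"][month][..., None, :]
    return base_demand * seasonal_factor

This function predicts demand based on:

  1. The input features for each store.
  2. The per-product weights and bias, which give a base demand.
  3. The seasonal factor for the given month, which scales the base demand.

The use of jnp.dot allows us to efficiently compute the base demand for all products and stores at once. We then multiply by the seasonal factor to get our final prediction. The seasonal factor has no store dimension, so we insert one before multiplying; without it, a batch of months with shape (n_samples,) would fail to broadcast against the (n_samples, n_stores, n_products) base demand.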
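
The shapes in this prediction step are easy to get wrong, so it helps to check them in isolation. The following sketch uses small, made-up sizes and random arrays (not the model’s parameters) to trace each intermediate shape.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes, deliberately all different so shape mix-ups are visible
n_samples, n_stores, n_products, n_features = 6, 3, 5, 4

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
features = jax.random.normal(k1, (n_samples, n_stores, n_features))
weights = jax.random.normal(k2, (n_products, n_features))
seasonal = jax.random.normal(k3, (12, n_products))
months = jax.random.randint(k4, (n_samples,), 0, 12)

# Contract the feature axis: (6, 3, 4) . (4, 5) -> (6, 3, 5)
base_demand = jnp.dot(features, weights.T)

# Look up one row of factors per sample: (6,) indices -> (6, 5)
per_sample_factors = seasonal[months]

# Add a store axis so (6, 1, 5) broadcasts against (6, 3, 5)
demand = base_demand * per_sample_factors[:, None, :]

print(base_demand.shape, per_sample_factors.shape, demand.shape)
```

Multiplying base_demand by per_sample_factors directly would raise a broadcasting error, because the trailing axes (3, 5) and (6, 5) do not line up.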

Step 4: Define the loss function and its gradient

Now, let’s define our loss function and its gradient:

def mse_loss(params, features, month, actual_demand):
    predicted_demand = predict(params, features, month)
    return jnp.mean((predicted_demand - actual_demand) ** 2)
 
grad_mse_loss = jit(grad(mse_loss))

We’re using Mean Squared Error (MSE) as our loss function. The grad_mse_loss function is created by applying jax.grad to our loss function and then JIT-compiling it for efficiency.
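
By default jax.grad differentiates with respect to the first argument, which is why params comes first in the loss signature, and the result is a pytree with the same structure as params. If you also need the loss value, jax.value_and_grad returns both in one pass. The toy model below (toy_loss, with keys "w" and "b") is a stand-in chosen for easy arithmetic, not the demand model itself.

```python
import jax
import jax.numpy as jnp

def toy_loss(params, x, y):
    # Mean squared error of a one-feature linear model
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# Returns (loss, gradients); gradients mirror the structure of params
loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))
loss, grads = loss_and_grad(params, x, y)

print(loss)   # mean of (4, 16, 36) = 56/3
print(grads)  # {"b": -8.0, "w": -56/3}
```

Using value_and_grad inside a training loop avoids a second forward pass just to log the loss.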

Step 5: Define the update function

Let’s define the function to update our model parameters:

@jit
def update(params, features, month, actual_demand, learning_rate=0.01):
    grads = grad_mse_loss(params, features, month, actual_demand)
    # jax.tree_map is deprecated in recent JAX releases; jax.tree_util.tree_map is the stable spelling
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

This function computes the gradients of our loss with respect to the parameters and then updates the parameters using gradient descent. The use of jax.tree_util.tree_map allows us to easily apply this update to our entire parameter tree.

Step 6: Generate sample data

To test our model, we need some sample data. Let’s create a function to generate this:

def generate_sample_data(n_samples, n_products, n_stores, n_features):
    key = jax.random.PRNGKey(0)
    features = jax.random.normal(key, (n_samples, n_stores, n_features))
    true_weights = jax.random.normal(jax.random.PRNGKey(1), (n_products, n_features))
    true_bias = jax.random.normal(jax.random.PRNGKey(2), (n_products,))
    true_seasonal_factors = jnp.abs(jax.random.normal(jax.random.PRNGKey(3), (12, n_products))) + 0.5
 
    months = jax.random.randint(jax.random.PRNGKey(4), (n_samples,), 0, 12)
    base_demand = jnp.dot(features, true_weights.T) + true_bias
    seasonal_factors = true_seasonal_factors[months]
    actual_demand = base_demand * seasonal_factors[:, None, :] + jax.random.normal(jax.random.PRNGKey(5), (n_samples, n_stores, n_products)) * 0.1
 
    return features, months, actual_demand

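This function generates:

  1. features: random store-level features with shape (n_samples, n_stores, n_features).
  2. months: a random month index from 0 to 11 for each sample.
  3. actual_demand: demand with shape (n_samples, n_stores, n_products), computed from hidden “true” weights, biases, and positive seasonal factors, plus a small amount of Gaussian noise.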

Step 7: Create the training loop

Now, let’s define our training loop:

def train_model(params, features, months, actual_demand, num_epochs=1000):
    for epoch in range(num_epochs):
        params = update(params, features, months, actual_demand)
        if epoch % 100 == 0:
            loss = mse_loss(params, features, months, actual_demand)
            print(f"Epoch {epoch}, Loss: {loss:.4f}")
    return params

This function iteratively updates our model parameters using the update function we defined earlier. It also prints out the loss every 100 epochs so we can monitor the training progress.

Step 8: Run the model

Finally, let’s put it all together and run our model:

# Main execution
n_samples, n_products, n_stores, n_features = 1000, 5, 3, 4
features, months, actual_demand = generate_sample_data(n_samples, n_products, n_stores, n_features)
 
params = init_params(n_products, n_stores, n_features)
trained_params = train_model(params, features, months, actual_demand)
 
# Visualize results for a single product and store
product_idx, store_idx = 0, 0
# Inside vmap each call sees one sample, so predict returns (n_stores, n_products)
predicted_demand = vmap(lambda f, m: predict(trained_params, f, m)[store_idx, product_idx])(features, months)
 
plt.figure(figsize=(12, 6))
plt.scatter(range(n_samples), actual_demand[:, store_idx, product_idx], alpha=0.5, label='Actual Demand')
plt.scatter(range(n_samples), predicted_demand, alpha=0.5, label='Predicted Demand')
plt.xlabel('Sample')
plt.ylabel('Demand')
plt.title(f'Actual vs Predicted Demand for Product {product_idx} in Store {store_idx}')
plt.legend()
plt.show()

This code:

  1. Generates sample data
  2. Initializes model parameters
  3. Trains the model
  4. Uses vmap to efficiently predict demand for all samples
  5. Visualizes the results for a single product and store

The use of vmap here is particularly noteworthy. It maps the prediction function over the leading sample axis, so we can write the logic for a single sample and let JAX batch it into array operations instead of looping in Python.
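
To see what vmap is doing, the sketch below compares a per-sample function wrapped in vmap against a hand-written batched computation. The function predict_one and the random inputs are illustrative stand-ins, not the trained model.

```python
import jax
import jax.numpy as jnp

def predict_one(weights, features, seasonal):
    # Single sample: features (n_stores, n_features), seasonal (n_products,)
    # Returns demand with shape (n_stores, n_products)
    return (features @ weights.T) * seasonal

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
weights = jax.random.normal(k1, (5, 4))      # (n_products, n_features)
features = jax.random.normal(k2, (8, 3, 4))  # (n_samples, n_stores, n_features)
seasonal = jax.random.normal(k3, (8, 5))     # (n_samples, n_products)

# in_axes: do not map over weights, map over axis 0 of the other two
batched_predict = jax.vmap(predict_one, in_axes=(None, 0, 0))
batched_out = batched_predict(weights, features, seasonal)

# Hand-written batched equivalent with explicit broadcasting
direct_out = jnp.dot(features, weights.T) * seasonal[:, None, :]

print(batched_out.shape, bool(jnp.allclose(batched_out, direct_out, atol=1e-5)))
```

Both versions produce the same array; vmap simply saves you from writing the broadcasting by hand.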

This expanded model introduces several new concepts and JAX features:

  1. Multi-dimensional data: We’re now handling data for multiple products across multiple stores, which requires more complex array operations.

  2. Seasonality: We’ve introduced seasonal factors to capture monthly variations in demand.

  3. Vectorization: We use vmap to efficiently apply our prediction function across multiple samples.

  4. Larger pytrees: Our params dictionary now holds several arrays of different shapes, and JAX differentiates and updates all of them together as a single pytree.

  5. Batch processing: Each update computes the loss over all samples, stores, and products in one array operation rather than looping over individual data points. For datasets too large to fit in memory, you would switch to mini-batches.

The model now predicts demand based on store-specific features, product-specific weights, and seasonal factors. This allows it to capture more complex patterns in the data, making it more suitable for real-world inventory optimization scenarios.

In a real-world application, you would typically split your data into training and validation sets, implement early stopping to prevent overfitting, and perhaps use more sophisticated optimization techniques like Adam instead of simple gradient descent.
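
As a rough illustration of the train/validation split and early stopping mentioned above, here is a minimal sketch on a small synthetic regression problem. The data, the 80/20 split, the learning rate, and the patience of 10 epochs are all assumptions made for the example, and the loop still uses plain gradient descent rather than Adam.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def step(params, x, y, lr=0.05):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic data (hypothetical): 200 samples, 3 features, small noise
k_x, k_noise, k_perm = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (200, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 3.0 + 0.1 * jax.random.normal(k_noise, (200,))

# Shuffle once, then hold out the last 20% for validation
perm = jax.random.permutation(k_perm, 200)
train_idx, val_idx = perm[:160], perm[160:]

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
best_val, best_params = float("inf"), params
patience, bad_epochs = 10, 0

for epoch in range(500):
    params = step(params, x[train_idx], y[train_idx])
    val = float(loss_fn(params, x[val_idx], y[val_idx]))
    if val < best_val - 1e-6:
        # Validation loss improved: remember these parameters
        best_val, best_params, bad_epochs = val, params, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # no improvement for `patience` epochs

print(f"Stopped after epoch {epoch}, best validation loss {best_val:.4f}")
```

The parameters to keep are best_params, not the final params, since the last few updates before stopping did not improve the validation loss.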

Next, we could further enhance this model by:

  1. Implementing a more sophisticated inventory optimization strategy based on the demand predictions.
  2. Adding constraints for warehouse capacity, budget limitations, or product shelf life.
  3. Incorporating external factors like promotions, competitors’ actions, or economic indicators.

These enhancements would bring us closer to a comprehensive inventory management system that could significantly improve business operations.
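
To give a flavor of the first enhancement, here is one possible sketch of turning demand predictions into an order quantity. It assumes a classic single-period “newsvendor” cost model, which this post has not introduced: each unsold unit incurs a holding cost and each unit of unmet demand incurs a stockout cost. The cost values and the simulated demand distribution are hypothetical, and a real system would draw demand samples from the forecasting model instead.

```python
import jax
import jax.numpy as jnp

HOLDING_COST = 1.0   # hypothetical cost per unsold unit
STOCKOUT_COST = 3.0  # hypothetical cost per unit of unmet demand

def expected_cost(order_qty, demand_samples):
    # Average cost of ordering `order_qty` over sampled demand scenarios
    overstock = jnp.maximum(order_qty - demand_samples, 0.0)
    understock = jnp.maximum(demand_samples - order_qty, 0.0)
    return jnp.mean(HOLDING_COST * overstock + STOCKOUT_COST * understock)

cost_grad = jax.jit(jax.grad(expected_cost))

# Stand-in for forecast uncertainty: demand ~ Normal(mean=100, std=15)
key = jax.random.PRNGKey(0)
demand_samples = 100.0 + 15.0 * jax.random.normal(key, (5000,))

# Gradient descent on the order quantity itself
order_qty = jnp.array(100.0)
for _ in range(500):
    order_qty = order_qty - 1.0 * cost_grad(order_qty, demand_samples)

# For this cost model the optimum is the demand quantile at
# STOCKOUT_COST / (STOCKOUT_COST + HOLDING_COST) = 0.75
target = jnp.quantile(demand_samples, 0.75)
print(f"Optimized order quantity: {float(order_qty):.1f}, 75th percentile: {float(target):.1f}")
```

Because stockouts are assumed to cost three times as much as holding, the optimized quantity lands above the mean demand. The same gradient-based approach extends to many products and stores by making order_qty an array.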

QUVO AI Blog © 2024