Handling ML Predictions in a Flask App

Don't let long-running code slow down your Flask app

Written by Austin Poor (5 min read)
Feb 18, 2021
data-science, python, flask, machine-learning, celery
An image of toy robots on an assembly line, generated by DALL-E.

One of the projects I worked on for the Metis data science bootcamp involved creating an MVP of a Flask app to display movie recommendations to a user (think Netflix home screen).

My recommendations combined a model prediction with a SQL query, and all of this ran when a request came in, before the response was sent. Come presentation day, loading the site's main page took about 30 seconds.

@app.route("/")
def main_page():
    # The slow prediction runs before any of the page is sent
    predictions = get_predictions()
    return render_template("index.html", predictions=predictions)

To be fair, this was an MVP I built on a short deadline in a data science bootcamp, not a web-dev bootcamp. Still, a 30-second page load is far too long.

After graduating, once I had more time, I took a second look at my project to see what I could have improved.

Here are two options I could have explored that would have drastically sped up my page loading time without having to change my prediction algorithm.

1) Load the Page, Then Make Predictions

Instead of making predictions before returning the main page (as in the code above), separate the prediction code from the page-response code in the Flask app. Return the page without predictions; then, once the page has loaded, call a prediction API route from JavaScript. Here's what the updated Flask app code would look like:

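from flask import Flask, jsonify, render_template

app = Flask(__name__)

@app.route("/")
def main_page():
    # Return the page immediately, without predictions
    return render_template("index.html")

@app.route("/api/predict")
def api_predict():
    # The slow prediction now runs in its own, separate request
    predictions = get_predictions()
    return jsonify(predictions)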

And here's what the JavaScript code would look like:

fetch("/api/predict")
  .then(r => r.json())
  .then(predictions => {
    // Update the page with predicted
    // movie recommendations...
  });

This is a small change to the initial code, but it can make a big difference to the user. The page can load right away with placeholder images or a loading bar, so users can still interact with your site while they wait for the predictions to load.

2) Pass off the Work to Celery

Running slow processes like ML predictions or complex SQL queries inside a Flask view function bogs down your Flask server: each one ties up one of the server's worker processes or threads for the full duration of the request. This might not be a concern for you, depending on how much traffic you expect. Maybe this is just a proof of concept (POC), or only a few people will use your service at a time. In that case, the API approach above should be fine. Otherwise, you might want a solution that can scale horizontally.
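A rough way to see the bottleneck: if every request holds one server worker for the whole prediction, the number of workers caps how many page loads can finish per minute. Here is a back-of-the-envelope sketch, assuming a hypothetical server with a fixed pool of synchronous workers and counting nothing but prediction time. The function name and the numbers are illustrative, not measurements from this project.

```python
def max_requests_per_minute(workers: int, seconds_per_request: float) -> float:
    """Upper bound on requests finished per minute when every request
    blocks one server worker for `seconds_per_request` seconds.

    Idealized: ignores network time, framework overhead, and requests
    that don't need a prediction.
    """
    if workers <= 0 or seconds_per_request <= 0:
        raise ValueError("workers and seconds_per_request must be positive")
    # Each worker can finish 60 / seconds_per_request requests per minute
    return workers * 60 / seconds_per_request


# Hypothetical setup: 4 worker processes, 30-second page loads
print(max_requests_per_minute(4, 30))  # 8.0
```

Any request beyond that bound waits in line behind the slow ones, including requests for pages that need no prediction at all.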

Enter Celery, a Python library for creating a “distributed task queue.” With Celery, you can run a pool of workers that handle tasks as they come in, and turning a Python function into a task is as easy as adding a decorator.
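To get a feel for what a task queue does before bringing in Redis and Celery, here is a minimal in-process stand-in using only the Python standard library. Work is submitted and an ID comes back immediately; background workers run the job and store its status and result under that ID. `TinyTaskQueue`, `submit`, `status`, and `wait_all` are hypothetical names, not Celery's API, and threads inside one process are no substitute for separate worker processes in production (CPU-bound Python code won't run in parallel on threads).

```python
import queue
import threading
import uuid


class TinyTaskQueue:
    """Hypothetical in-process stand-in for a task queue plus result backend.

    Not Celery's API: jobs run on background threads and results live in
    a dict keyed by task ID.
    """

    def __init__(self, n_workers: int = 2):
        self._jobs = queue.Queue()
        self._results = {}  # task_id -> {"status": ..., "result": ...}
        self._lock = threading.Lock()
        for _ in range(n_workers):
            threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        # Each worker loops forever, pulling the next job off the queue
        while True:
            task_id, func, args = self._jobs.get()
            try:
                update = {"status": "SUCCESS", "result": func(*args)}
            except Exception as exc:
                update = {"status": "FAILURE", "error": repr(exc)}
            with self._lock:
                self._results[task_id] = update
            self._jobs.task_done()

    def submit(self, func, *args) -> str:
        """Enqueue work and return its ID right away (like .delay())."""
        task_id = str(uuid.uuid4())
        with self._lock:
            self._results[task_id] = {"status": "PENDING"}
        self._jobs.put((task_id, func, args))
        return task_id

    def status(self, task_id: str) -> dict:
        """Report a task's status, plus its result once it has finished.

        Unknown IDs report PENDING, mirroring Celery's behavior.
        """
        with self._lock:
            return dict(self._results.get(task_id, {"status": "PENDING"}))

    def wait_all(self):
        """Block until every submitted job has been processed (for demos)."""
        self._jobs.join()


tq = TinyTaskQueue()
tid = tq.submit(lambda x: x * 2, 21)  # returns immediately
tq.wait_all()
print(tq.status(tid))  # {'status': 'SUCCESS', 'result': 42}
```

The two methods `submit` and `status` line up with the two API routes the Celery version needs: one to schedule work and one to check on it.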

With our new Celery workflow, we'll split the API route into two: one route for scheduling a prediction and another for getting the prediction results.

Let's take a look at the updated Flask snippet:

from celery.result import AsyncResult
from flask import Flask, render_template, request

from celery_predict import celery_app, celery_predict

app = Flask(__name__)

@app.route("/")
def main_page():
    return render_template("index.html")

@app.route("/api/predict", methods=["POST"])
def api_predict():
    # Get the input data from the POST request's JSON body
    data = request.json.get("input-data")
    # Schedule the prediction with Celery; .delay() returns immediately
    task = celery_predict.delay(data)
    # Return the task's ID so the client can poll for the result
    return {"status": "working", "result-id": task.id}

@app.route("/api/get-result/<taskid>")
def get_result(taskid: str):
    # Look up the Celery task by its ID
    task = AsyncResult(taskid, app=celery_app)
    # If the task isn't finished, return only its status
    if task.state != "SUCCESS":
        return {"status": task.state}
    # Otherwise, include the result as well
    return {"status": task.state, "result": task.get()}

And here's the new Celery snippet:

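from celery import Celery

# Redis serves as both the message broker (the task queue) and the
# result backend (where finished results are stored)
celery_app = Celery(
    "predict",
    broker="redis://redis:6379/0",
    backend="redis://redis:6379/0",
)

@celery_app.task
def celery_predict(input_data: dict):
    # ML prediction code...
    preds = ...
    # The return value must be serializable (JSON by default),
    # e.g. a list rather than a NumPy array
    return preds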

And the updated JavaScript:

function checkForResults(rid = "") {
  // Check if result is ready
  fetch(`/api/get-result/${rid}`)
    .then(r => r.json())
    .then(r => {
      const stat = r.status;
      // Check if result is success, failure,
      // or not ready yet
      if (stat === "SUCCESS") {
        handleSuccess(r);
      } else if (stat === "FAILURE") {
        handleError(r);
      } else {
        // Wait 100 ms and try again
        setTimeout(() => checkForResults(rid), 100);
      }
    });
}

function makePrediction() {
  // Prepare data to make a prediction
  const data = getInputData();
  // Schedule the prediction, then poll until the result is ready
  fetch("/api/predict", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ "input-data": data }),
  })
    .then(r => r.json())
    .then(r => checkForResults(r["result-id"]));
}

Now the page loads first. Then the JavaScript schedules the model prediction and keeps checking for results until they're ready.
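The same polling loop can be sketched in Python, for example for a script or a test client. One deliberate difference from the JavaScript: it gives up after a timeout instead of polling forever. `wait_for_result` and the injected `fetch_status` callable are hypothetical; in practice `fetch_status` might wrap an HTTP GET to the `/api/get-result/<taskid>` route and return the parsed JSON.

```python
import time
from typing import Callable


def wait_for_result(
    fetch_status: Callable[[str], dict],
    task_id: str,
    interval: float = 0.1,
    timeout: float = 30.0,
) -> dict:
    """Poll fetch_status(task_id) until the task finishes or time runs out."""
    deadline = time.monotonic() + timeout
    while True:
        response = fetch_status(task_id)
        # SUCCESS and FAILURE are final; any other state means "not yet"
        if response.get("status") in ("SUCCESS", "FAILURE"):
            return response
        if time.monotonic() >= deadline:
            raise TimeoutError(f"task {task_id} not finished after {timeout} s")
        # Wait before asking again, like the setTimeout in checkForResults
        time.sleep(interval)


# A fake status endpoint that reports PENDING twice, then SUCCESS
replies = iter([
    {"status": "PENDING"},
    {"status": "PENDING"},
    {"status": "SUCCESS", "result": [1, 2, 3]},
])
print(wait_for_result(lambda tid: next(replies), "abc", interval=0.01))
# {'status': 'SUCCESS', 'result': [1, 2, 3]}
```

Injecting `fetch_status` keeps the loop testable without a running server; swapping in a real HTTP call changes nothing else.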

Admittedly, this increases the complexity of our solution (you'll need to add a message broker like Redis and start a separate Celery worker process), but it lets us scale the app horizontally by adding as many Celery workers to the pool as we need.
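Why more workers help: queued predictions are spread across the pool, so the wait for the last user in line shrinks roughly in proportion to the number of workers. A back-of-the-envelope sketch, assuming every prediction takes the same time and each worker runs one task at a time (the function name and numbers are hypothetical, not measurements):

```python
import math


def seconds_to_clear_queue(jobs: int, seconds_per_job: float, workers: int) -> float:
    """Idealized time for `workers` identical workers to finish `jobs`
    equal-length predictions, ignoring broker and startup overhead."""
    if workers <= 0:
        raise ValueError("need at least one worker")
    # Jobs run in "rounds" of up to `workers` at a time
    return math.ceil(jobs / workers) * seconds_per_job


# Hypothetical: 10 users request 30-second predictions at once
for n in (1, 2, 5, 10):
    print(f"{n} worker(s): {seconds_to_clear_queue(10, 30, n)} s")
# 300, 150, 60, and 30 seconds respectively
```

Real workloads are messier (uneven prediction times, worker concurrency settings, broker latency), but the direction holds: the Flask server stays responsive, and throughput grows with the worker pool.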

For a complete example, check out this GitHub repo: a-poor/flask-celery-ml.
