I've been trying to deploy a very simple toy Keras model to Cloud Functions, which would predict the class of an image, but for reasons unknown, when the execution gets to the predict method, it gets stuck, does not throw any error, and eventually times out.

import functions_framework
import io
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from PIL import Image

model = load_model("gs://<my-bucket>/cifar10_model.keras")

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def preprocess_image(image_file):
    img = Image.open(io.BytesIO(image_file.read()))
    img = img.resize((32, 32))
    img = np.array(img)
    img = img / 255.0
    img = img.reshape(1, 32, 32, 3)
    return img

@functions_framework.http
def predict(request):
    image = preprocess_image(request.files['image_file'])
    print(image.shape) # this prints OK
    prediction = model.predict(image)
    print(prediction) # this never prints
    predicted_class = class_names[np.argmax(prediction)]
    return f"Predicted class: {predicted_class}"

Debugging locally works fine, the prediction is fast as expected (the model weights file is 2MB). I also added several prints along the way (removed from the snippet above) and the execution works fine until the predict method.

Even though the minimal compute configuration should be enough, I tried reserving more memory and CPU, but nothing helped. The model is hosted in Cloud Storage; I tried downloading it first, but that didn't work either. I also tried making the prediction inside a tf.device('/cpu:0') context, passing a steps=1 parameter, and converting the image array to a tf.data.Dataset first, as suggested by ChatGPT, all with the same results. In fact, invoking predict prints nothing at all. Calling call instead of predict got me nowhere either.

What am I missing?

3 Answers

I suggest downloading the model to local storage (/tmp) and loading it once, on the first request after a cold start. Here’s an overview of the code:

import os
import tempfile # add these imports after [from PIL import Image]

model = None
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def get_model():
    global model
    if model is None: 
        # Download the model to the temp directory (/tmp) only once
        local_model_path = os.path.join(tempfile.gettempdir(), "cifar10_model.keras")
        if not os.path.exists(local_model_path):
            print("Downloading model...")
            tf.io.gfile.copy("gs://<my-bucket>/cifar10_model.keras", local_model_path, overwrite=True)
        print("Loading model...")
        model = load_model(local_model_path)
    return model

def preprocess_image(image_file):
    img = Image.open(io.BytesIO(image_file.read()))
    img = img.resize((32, 32))
    img = np.array(img) / 255.0
    img = img.reshape(1, 32, 32, 3)
    return img

@functions_framework.http
def predict(request):
    global model
    model = get_model() # get or load the model
    image = preprocess_image(request.files['image_file'])
    print("Image shape:", image.shape)
    prediction = model.predict(image)
    print("Prediction:", prediction) 
    predicted_class = class_names[np.argmax(prediction)]
    return f"Predicted class: {predicted_class}"

Hope this works on your end!

1 Comment

I appreciate your help. It did work, but I was invested in writing as little code as possible, so, based on your suggestion, I came up with my own solution, which I am posting below. In the end, downloading the model was not the issue; rather, the model must be instantiated inside the predict function. Nevertheless, thanks a lot!

Even though the previous suggestion worked, I was invested in writing as little code as possible. Therefore, I came up with this solution:

import functions_framework
import io
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from PIL import Image

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def preprocess_image(image_file):
    img = Image.open(io.BytesIO(image_file.read()))
    img = img.resize((32, 32))
    img = np.array(img) / 255.0
    img = img.reshape(1, 32, 32, 3)
    return img

@functions_framework.http
def predict(request):
    model = load_model("gs://<my-bucket>/cifar10_model.keras")
    image = preprocess_image(request.files['image_file'])
    prediction = model.predict(image)
    predicted_class = class_names[np.argmax(prediction)]
    return f"Predicted class: {predicted_class}"

It seems that the issue is that the model must be instantiated inside the function that processes the request. I would have never guessed!
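Note that loading inside predict repeats the download and deserialization on every request. A possible middle ground, sketched below, is to keep the load inside the request path but cache the result so a warm instance reuses it. This is only a sketch under assumptions: `_load_model_from_storage` is a hypothetical stand-in for `load_model("gs://<my-bucket>/cifar10_model.keras")`, and `LOAD_CALLS` exists purely to show how often the loader runs.

```python
import functools

LOAD_CALLS = 0  # illustration only: counts how often the loader actually runs


def _load_model_from_storage():
    """Hypothetical stand-in for load_model("gs://<my-bucket>/cifar10_model.keras")."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return object()


@functools.lru_cache(maxsize=1)
def get_model():
    # The first call runs the loader; later calls in the same process
    # return the cached object without loading again.
    return _load_model_from_storage()


first = get_model()
second = get_model()
```

With this in place, predict would call `get_model()` instead of `load_model(...)`, so the model is still first instantiated while handling a request, not at import time.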

Your function gets stuck when it tries to predict, and the issue is not in your code but in how the Google Cloud environment works. Cloud Functions have limits on memory and execution time, as well as constraints on where and how files are loaded and processed.

Try out these steps:

  1. Download the model file to a temporary folder inside the function before using it.
from google.cloud import storage
import tensorflow as tf

def load_model_from_cloud(bucket_name, model_file):
    # Connect to source (in your case it's cloud storage bucket)
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(model_file)

    # Download the model to a temporary folder or any other you like
    local_path = f"/tmp/{model_file}"
    blob.download_to_filename(local_path)

    return tf.keras.models.load_model(local_path)

model = load_model_from_cloud("<your-bucket-name>", "cifar10_model.keras")
  2. Add debugging (extra logs) to identify exactly what is going on under the hood. That will give more clarity about your problem.
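As a rough illustration of the extra logging in step 2, here is a minimal sketch that logs the start and end of each step and records how long it took. The `log_step` helper and the step names are hypothetical, and the list arithmetic only stands in for `preprocess_image(...)` and `model.predict(...)`. If a step hangs, its "start" line appears in the logs without a matching "done" line.

```python
import logging
import time
from contextlib import contextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("predict")


@contextmanager
def log_step(name, timings):
    """Log when a step starts and finishes, and record its duration in seconds."""
    logger.info("start: %s", name)
    started = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = time.perf_counter() - started
        logger.info("done: %s (%.3fs)", name, timings[name])


timings = {}
with log_step("preprocess", timings):
    data = [x / 255.0 for x in range(4)]  # stand-in for preprocess_image(...)
with log_step("predict", timings):
    result = max(data)  # stand-in for model.predict(...)
```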
