Image Classification with Vertex AI

A step-by-step guide to training and deploying image classification models using Google Vertex AI AutoML Vision

Welcome to the Guide!

This tutorial is designed for developers, data scientists, and students who want to learn how to build image classification models without deep machine learning expertise.

We'll use Google Vertex AI's AutoML Vision, which automates much of the model training process while still delivering high-quality results. No need to write complex neural network architectures!

By the end of this guide, you'll be able to:

  • Prepare image datasets for classification
  • Train custom models with AutoML Vision
  • Evaluate model performance
  • Deploy models to production endpoints
  • Make predictions using the Python SDK

Prerequisites

Google Cloud Account

You'll need a Google Cloud account with billing enabled. Vertex AI is a paid service, but new users get $300 in free credits.

Google Cloud Project

Create a new project or select an existing one in the Google Cloud Console where you'll enable the Vertex AI API.

Vertex AI API Enabled

Enable the Vertex AI API for your project. This can be done in the "APIs & Services" section of the Cloud Console.
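
If you prefer the command line, the API can also be enabled with gcloud:

gcloud services enable aiplatform.googleapis.com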

Cloud Storage Bucket

Create a Cloud Storage bucket to store your training data. The bucket should be in the same region where you'll train your model.
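
For example, using gsutil (the bucket name and region are placeholders):

gsutil mb -l us-central1 gs://your-bucket-name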

Python Environment

Set up a Python environment (3.7+) with the Google Cloud SDK installed. We recommend using a virtual environment.
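
For example (the environment name is arbitrary):

python3 -m venv vertex-env
source vertex-env/bin/activate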

Authentication

Set up authentication by creating a service account and downloading the JSON key file. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable.
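
On macOS or Linux, for example (the key path is a placeholder):

export GOOGLE_APPLICATION_CREDENTIALS="/path/to/your-key.json"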

Install Required Packages

Install the Vertex AI SDK, the Cloud Storage client library (used in the helper sketches later in this guide), and pandas:

pip install google-cloud-aiplatform google-cloud-storage pandas

Step-by-Step Tutorial

Step 1: Dataset Preparation

For image classification with AutoML Vision, a folder-per-class layout in Cloud Storage keeps labeling straightforward:

Folder Structure:

gs://your-bucket-name/
    ├── train/
    │   ├── class1/
    │   │   ├── image1.jpg
    │   │   ├── image2.jpg
    │   │   └── ...
    │   ├── class2/
    │   │   ├── image1.jpg
    │   │   ├── image2.jpg
    │   │   └── ...
    │   └── ...
    └── test/
        ├── class1/
        ├── class2/
        └── ...

Requirements:

  • Minimum 10 images per class (100+ recommended for better performance)
  • Images should be in JPEG or PNG format
  • Each image should be at least 800x600 pixels
  • Balance your dataset across classes

Upload to Cloud Storage:

Use the Google Cloud Console or gsutil command-line tool to upload your dataset:

gsutil -m cp -r /path/to/local/dataset gs://your-bucket-name
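
Create an import file: Vertex AI ingests image datasets through a CSV (or JSONL) import file in which each row pairs an image's Cloud Storage URI with its label; the folder layout above makes this easy to generate. Here is a minimal sketch using the google-cloud-storage client (project, bucket, and file names are placeholders to adapt):

from google.cloud import storage

# Walk the uploaded images and derive each label from its folder name
client = storage.Client(project="your-project-id")
bucket = client.bucket("your-bucket-name")

rows = []
for blob in bucket.list_blobs(prefix="train/"):
    if blob.name.lower().endswith((".jpg", ".jpeg", ".png")):
        label = blob.name.split("/")[1]  # path layout is train/<label>/<image>
        rows.append(f"gs://your-bucket-name/{blob.name},{label}")

# Write the import file to the bucket so Step 2 can reference it
bucket.blob("import_file.csv").upload_from_string("\n".join(rows))

Each row of the resulting file looks like gs://your-bucket-name/train/class1/image1.jpg,class1.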

Step 2: Create a Vertex AI Dataset

Now we'll create a dataset resource in Vertex AI that points to your Cloud Storage data.

Using the Python SDK:

from google.cloud import aiplatform

# Initialize the Vertex AI client
aiplatform.init(project="your-project-id", location="us-central1")

# Create an image dataset
# Create an image dataset from the import CSV built in Step 1
dataset = aiplatform.ImageDataset.create(
    display_name="flowers-classification",
    gcs_source="gs://your-bucket-name/import_file.csv",
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
)

print(f"Created dataset: {dataset.resource_name}")

Alternative: Using the Console

  1. Go to the Vertex AI section in Google Cloud Console
  2. Navigate to "Datasets" and click "Create"
  3. Select "Image classification (Single-label)"
  4. Enter a name and select your region
  5. Choose "Select import files from Cloud Storage" and enter the path to your import CSV (for example, gs://your-bucket-name/import_file.csv)
  6. Click "Create"

Step 3: Train the AutoML Model

With your dataset ready, you can now train an AutoML Vision model. This process will automatically:

  • Split your data into training, validation, and test sets
  • Select the best model architecture
  • Tune hyperparameters
  • Train and evaluate the model

Using the Python SDK:

# Define training job
training_job = aiplatform.AutoMLImageTrainingJob(
    display_name="train-flowers-classification",
    prediction_type="classification",
    multi_label=False,
    model_type="CLOUD",
)

# Run the training job
model = training_job.run(
    dataset=dataset,
    training_fraction_split=0.8,
    validation_fraction_split=0.1,
    test_fraction_split=0.1,
    budget_milli_node_hours=8000,  # 8 node hours
    disable_early_stopping=False,
)

print(f"Training completed. Model: {model.resource_name}")

Training Considerations:

  • Budget: A larger budget_milli_node_hours generally yields a better model; with early stopping enabled, training ends sooner if quality stops improving
  • Model Type: "CLOUD" for best accuracy, "MOBILE" for edge deployment
  • Monitoring: Track progress in the Vertex AI Console

Step 4: Evaluate the Model

After training completes, you'll want to evaluate the model's performance before deployment.

View Evaluation Metrics:

# List the evaluations produced during AutoML training
evaluations = model.list_model_evaluations()
evaluation = evaluations[0]

# Exact keys vary by model type; print the full mapping to inspect what's available
metrics = dict(evaluation.metrics)
print(f"AU-PRC: {metrics['auPrc']}")
print(f"Confusion Matrix: {metrics['confusionMatrix']}")

# Precision, recall, and F1 are reported per confidence threshold
for entry in metrics["confidenceMetrics"]:
    if entry.get("confidenceThreshold", 0) == 0.5:
        print(f"Precision@0.5: {entry['precision']}")
        print(f"Recall@0.5: {entry['recall']}")
        print(f"F1@0.5: {entry['f1Score']}")

Key Metrics to Check:

  • Precision: Percentage of correct positive predictions
  • Recall: Percentage of actual positives correctly identified
  • F1 Score: Harmonic mean of precision and recall
  • Confusion Matrix: Shows performance per class

Console Visualization:

For a more visual evaluation, check the "Evaluate" tab in the Vertex AI Console where you can see:

  • Precision-recall curves
  • Confusion matrix visualization
  • Example predictions with confidence scores

Step 5: Deploy the Model

To make predictions, you need to deploy your model to an endpoint. This creates a scalable service that can handle prediction requests.

Using the Python SDK:

# Create an endpoint
endpoint = aiplatform.Endpoint.create(
    display_name="flowers-classification-endpoint",
    project="your-project-id",
    location="us-central1",
)

# Deploy the model to the endpoint
endpoint.deploy(
    model=model,
    deployed_model_display_name="flowers-classification-model",
    traffic_percentage=100,
    machine_type="n1-standard-4",  # Choose appropriate machine type
    min_replica_count=1,
    max_replica_count=1,
)

print(f"Model deployed to endpoint: {endpoint.resource_name}")

Deployment Considerations:

  • Machine Type: Choose based on expected traffic (n1-standard-2 for testing, larger for production)
  • Scaling: Set min/max replicas for automatic scaling
  • Cost: You're billed while the endpoint is running
  • Undeploy: Remember to undeploy when not in use to avoid charges (see the teardown sketch below)
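
When you're done experimenting, a minimal teardown looks like this (it removes every deployed model and deletes the endpoint; the trained model resource itself is kept and can be redeployed later):

# Undeploy all models from the endpoint, then delete the endpoint
endpoint.undeploy_all()
endpoint.delete()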

Step 6: Make Predictions

With your model deployed to an endpoint, you can now make predictions on new images.

Using the Python SDK:

import base64

# Function to encode image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

# Example prediction
image_path = "path/to/your/test_image.jpg"
encoded_image = encode_image(image_path)

# Make prediction
prediction = endpoint.predict(
    instances=[{"content": encoded_image}],
    parameters={"confidenceThreshold": 0.5},  # Minimum confidence score
)

# Process results
for result in prediction.predictions:
    print("Predicted classes:")
    for i, (label, score) in enumerate(zip(result["displayNames"], result["confidences"])):
        print(f"{i+1}. {label}: {score:.2%}")

Alternative: Batch Prediction

For predicting on many images at once, use batch prediction:

# Create a batch prediction job; gcs_source must point to a JSONL input file,
# not the raw images. Each line references one image, e.g.
# {"content": "gs://your-bucket-name/test/class1/image1.jpg", "mimeType": "image/jpeg"}
batch_job = model.batch_predict(
    job_display_name="batch-pred-flowers",
    gcs_source="gs://your-bucket-name/batch_input.jsonl",
    gcs_destination_prefix="gs://your-bucket-name/predictions/",
    instances_format="jsonl",
    predictions_format="jsonl",
)

print(f"Batch prediction job: {batch_job.resource_name}")

Prediction Options:

  • Online Prediction: Low-latency requests to the endpoint (good for real-time applications)
  • Batch Prediction: Process many images at once (good for offline processing)
  • Confidence Threshold: Filter predictions by minimum confidence score
