Train a TensorFlow model


This tutorial demonstrates how to run a TensorFlow job at scale using Azure ML. You will train a TensorFlow model to classify handwritten digits (MNIST) using a deep neural network (DNN) and log your results to the Azure ML service.


If you don’t have access to an Azure ML workspace, follow the setup tutorial to configure and create a workspace.

Set up development environment

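The setup for your development work in this tutorial includes the following actions: importing the azuremlsdk package, loading your workspace, creating an experiment, and creating a compute target.

Import azuremlsdk package

library(azuremlsdk)
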
Load your workspace

Instantiate a workspace object from your existing workspace. The following code will load the workspace details from a config.json file if you previously wrote one out with write_workspace_config().

ws <- load_workspace_from_config()

Or, you can retrieve a workspace by directly specifying your workspace details:

ws <- get_workspace("<your workspace name>", "<your subscription ID>", "<your resource group>")
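If you want load_workspace_from_config() to work in later sessions without first calling write_workspace_config(), you can create the config file yourself. The sketch below assumes the file holds the three keys subscription_id, resource_group and workspace_name, which is the layout we understand write_workspace_config() to produce; the helper write_workspace_json() is hypothetical and not part of azuremlsdk.

```r
# Hypothetical helper (not part of azuremlsdk): writes a config.json with the
# three workspace keys. Assumption: this matches the layout written by
# write_workspace_config() and read by load_workspace_from_config().
write_workspace_json <- function(workspace_name, subscription_id, resource_group,
                                 path = ".") {
  dir.create(path, showWarnings = FALSE, recursive = TRUE)
  json <- sprintf(
    '{\n  "subscription_id": "%s",\n  "resource_group": "%s",\n  "workspace_name": "%s"\n}',
    subscription_id, resource_group, workspace_name
  )
  file <- file.path(path, "config.json")
  writeLines(json, file)
  invisible(file)
}

# Placeholder values written to a temporary folder for illustration;
# substitute your own details and directory.
cfg <- write_workspace_json("<your workspace name>", "<your subscription ID>",
                            "<your resource group>", path = tempdir())
cat(readLines(cfg), sep = "\n")
```

Writing the file to a temporary folder keeps the example harmless; in practice you would write it to your project directory, where load_workspace_from_config() starts its search by default.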

Create an experiment

An Azure ML experiment tracks a grouping of runs, typically from the same training script. Create an experiment to track the runs for training the TensorFlow model on the MNIST data.

exp <- experiment(workspace = ws, name = "tf-mnist")

If you would like to track your runs in an existing experiment, pass that experiment’s name to the name parameter of experiment().

Create a compute target

By using Azure Machine Learning Compute (AmlCompute), a managed service, data scientists can train machine learning models on clusters of Azure virtual machines. In this tutorial, you create a GPU-enabled cluster as your training environment. The code below creates the compute cluster for you if it doesn’t already exist in your workspace.

You may need to wait a few minutes for your compute cluster to be provisioned if it doesn’t already exist.

cluster_name <- "gpucluster"

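# get_compute() returns NULL if no cluster with this name exists in the workspace
compute_target <- get_compute(ws, cluster_name = cluster_name)
if (is.null(compute_target)) {
  # GPU VM size; availability depends on your region and quota
  vm_size <- "STANDARD_NC6"
  compute_target <- create_aml_compute(workspace = ws,
                                       cluster_name = cluster_name,
                                       vm_size = vm_size,
                                       max_nodes = 4)
  # Block until the new cluster is ready
  wait_for_provisioning_completion(compute_target, show_output = TRUE)
}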

Prepare the training script

A training script called tf_mnist.R has been provided for you in the train-with-tensorflow/ subfolder of this vignette. The Azure ML SDK provides a set of logging APIs for logging various metrics during training runs. These metrics are recorded and persisted in the experiment run record, and can be accessed at any time or viewed on the run details page in Azure Machine Learning studio.

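In order to collect and upload run metrics, you need to do the following inside the training script:

Import the azuremlsdk package:

library(azuremlsdk)

Call log_metric_to_run() to track the primary metric for this experiment, “accuracy”. If your own training script has several important metrics, add a logging call for each one.

log_metric_to_run("accuracy",
                  sess$run(accuracy,
                           feed_dict = dict(x = mnist$test$images, y_ = mnist$test$labels)))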

See the reference for the full set of logging methods log_*() available from the R SDK.
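Because log_metric_to_run() records values against a run, it can be convenient to let the same training script also run on your own machine while you develop it. The sketch below is one way to do that, under two assumptions we have not confirmed for every setup: that an AZUREML_RUN_ID environment variable is present only inside a submitted Azure ML job, and that calling log_metric_to_run(name, value) with no run argument logs to the current run. The wrapper log_metric() and the toy accuracy loop are hypothetical illustrations, not part of tf_mnist.R.

```r
# Hypothetical wrapper around azuremlsdk::log_metric_to_run().
# Assumption: AZUREML_RUN_ID is set only when running as a submitted Azure ML job.
# Outside a job, values are appended to a local list instead.
local_metrics <- list()

log_metric <- function(name, value) {
  in_azureml <- nzchar(Sys.getenv("AZUREML_RUN_ID")) &&
    requireNamespace("azuremlsdk", quietly = TRUE)
  if (in_azureml) {
    azuremlsdk::log_metric_to_run(name, value)
  } else {
    # <<- updates the top-level local_metrics list
    local_metrics[[name]] <<- c(local_metrics[[name]], value)
  }
  invisible(value)
}

# Toy evaluation on synthetic labels: accuracy is the share of predictions
# that match the labels, logged once per "epoch".
set.seed(1)
labels <- sample(0:9, 100, replace = TRUE)
for (epoch in 1:3) {
  correct <- runif(100) < 0.5 + 0.15 * epoch
  preds <- ifelse(correct, labels, (labels + 1) %% 10)
  log_metric("accuracy", mean(preds == labels))
}
str(local_metrics)
```

Logging the same metric name repeatedly builds up a series of values, which is how a per-epoch curve would appear for the run.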

Create an estimator

An Azure ML estimator encapsulates the run configuration information needed for executing a training script on the compute target. Azure ML executes each run as a containerized job on the specified compute target.

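To create the estimator, define the following: the directory that contains the training script and any files it needs (source_directory), the script to execute (entry_script), the compute target to run on (compute_target), and the environment to run in (environment). Here the environment is an R environment built from the custom Docker image amlsamples/r-tensorflow:latest.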

env <- r_environment("tensorflow-env", custom_docker_image = "amlsamples/r-tensorflow:latest")

est <- estimator(source_directory = "train-with-tensorflow",
                 entry_script = "tf_mnist.R",
                 compute_target = compute_target,
                 environment = env)
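The estimator above runs tf_mnist.R without arguments. If you later want to vary settings such as the number of training steps between runs, estimator() also accepts a script_params argument (a named list) in the azuremlsdk versions we are aware of, and the entry script can read the values with commandArgs(trailingOnly = TRUE). The sketch below assumes they arrive as "--name value" pairs; parse_script_args() and the parameter names are hypothetical.

```r
# Hypothetical helper for an entry script: turns "--name value" pairs into a
# named list. Assumption: arguments are passed in that form, e.g. via script_params.
parse_script_args <- function(args = commandArgs(trailingOnly = TRUE)) {
  if (length(args) == 0) return(list())
  if (length(args) %% 2 != 0) stop("expected --name value pairs")
  keys <- args[c(TRUE, FALSE)]  # odd positions: --name
  vals <- args[c(FALSE, TRUE)]  # even positions: value
  if (!all(startsWith(keys, "--"))) stop("argument names must start with --")
  out <- as.list(vals)
  names(out) <- sub("^--", "", keys)
  out
}

# Example with explicit arguments; values arrive as strings, so convert them.
params <- parse_script_args(c("--epochs", "10", "--learning_rate", "0.5"))
epochs <- as.integer(params$epochs)
learning_rate <- as.numeric(params$learning_rate)
```

Converting types in the script keeps the parser generic, since command-line values are always character strings.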

Submit the job

Finally, submit the job to run on your cluster. submit_experiment() returns a Run object that you can then use to interface with the run.

run <- submit_experiment(exp, est)

You can view the run’s details as a table. Clicking the “Web View” link in the output opens Azure Machine Learning studio, where you can monitor the run in the UI.

view_run_details(run)

Model training happens in the background. Wait until the model has finished training before you run more code.

wait_for_run_completion(run, show_output = TRUE)

View run metrics

Once your job has finished, you can view the metrics collected during your TensorFlow run.

metrics <- get_run_metrics(run)
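To get a quick overview of what was logged, you can tabulate the metrics. The sketch below assumes get_run_metrics() returns a named list keyed by metric name, where a metric logged several times holds several values; summarise_metrics() is a hypothetical helper, and the example list is synthetic rather than output from this tutorial’s run.

```r
# Hypothetical helper: one row per metric with its last logged value and the
# number of values logged. Assumes a named list shaped like get_run_metrics() output.
summarise_metrics <- function(metrics) {
  data.frame(
    metric = names(metrics),
    last   = vapply(metrics, function(v) as.numeric(tail(unlist(v), 1)), numeric(1)),
    n      = vapply(metrics, function(v) length(unlist(v)), integer(1)),
    row.names = NULL,
    stringsAsFactors = FALSE
  )
}

# Synthetic example list standing in for real run metrics
example <- list(accuracy = c(0.91, 0.95, 0.97), loss = 0.12)
print(summarise_metrics(example))
```

With a finished run, summarise_metrics(metrics) gives the same kind of table for the values your training script logged.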

Clean up resources

Delete the resources once you no longer need them. Don’t delete any resource you still plan to use.

Delete the compute cluster: