Distributed Training and Inference¶
Orca Estimator provides sklearn-style APIs for transparently distributed model training and inference.
1. Estimator¶
To perform distributed training and inference, the user can first create an Orca Estimator from any standard (single-node) TensorFlow, Keras or PyTorch model, and then call the Estimator.fit or Estimator.predict methods (using the data-parallel processing pipeline as input).
Under the hood, the Orca Estimator will replicate the model on each node in the cluster, feed the data partition (generated by the data-parallel processing pipeline) on each node to the local model replica, and synchronize model parameters using various backend technologies (such as Horovod, tf.distribute.MirroredStrategy, torch.distributed, or the parameter sync layer in BigDL).
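The data-parallel pattern described above can be illustrated with a small framework-free sketch (plain NumPy; an illustration of the idea, not Orca's actual mechanism): each "node" holds a replica of the parameters, computes a gradient on its own data partition, and the replicas stay synchronized by averaging the gradients before each update.

```python
import numpy as np

def local_gradient(w, X, y):
    # gradient of the mean-squared-error loss for a linear model y ~ X @ w
    return 2 * X.T @ (X @ w - y) / len(y)

def data_parallel_step(w, partitions, lr=0.1):
    # each "node" computes a gradient from its own data partition on its replica...
    grads = [local_gradient(w, X, y) for X, y in partitions]
    # ...then the replicas are synchronized by averaging the gradients
    return w - lr * np.mean(grads, axis=0)

rng = np.random.default_rng(0)
X, y = rng.normal(size=(8, 3)), rng.normal(size=8)
partitions = [(X[:4], y[:4]), (X[4:], y[4:])]  # two "nodes", two partitions
w = np.zeros(3)
for _ in range(200):
    w = data_parallel_step(w, partitions)
```

Because the averaged gradient over equal-sized partitions equals the full-batch gradient, every replica ends each step with identical parameters, which is what the backends listed above guarantee at scale.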
2. TensorFlow/Keras Estimator¶
2.1 TensorFlow 1.15 and Keras 2.3¶
There are two ways to create an Estimator for TensorFlow 1.15: either from a low-level computation graph or from a Keras model. Examples are as follows:
TensorFlow Computation Graph:
# define inputs to the graph
images = tf.placeholder(dtype=tf.float32, shape=(None, 28, 28, 1))
labels = tf.placeholder(dtype=tf.int32, shape=(None,))

# define the network and loss
logits = lenet(images)
loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels))

# define a metric
acc = accuracy(logits, labels)

# create an estimator using endpoints of the graph
est = Estimator.from_graph(inputs=images,
                           outputs=logits,
                           labels=labels,
                           loss=loss,
                           optimizer=tf.train.AdamOptimizer(),
                           metrics={"acc": acc})
Keras Model:
model = create_keras_lenet_model()
model.compile(optimizer=keras.optimizers.RMSprop(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
est = Estimator.from_keras(keras_model=model)
Then users can perform distributed model training and inference as follows:
dataset = tfds.load(name="mnist", split="train")
dataset = dataset.map(preprocess)

est.fit(data=dataset,
        batch_size=320,
        epochs=max_epoch)

predictions = est.predict(data=df,
                          feature_cols=['image'])
The data argument in the fit method can be a Spark DataFrame, an XShards or a tf.data.Dataset. The data argument in the predict method can be a Spark DataFrame or an XShards. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
2.2 TensorFlow 2.x and Keras 2.4+¶
Users can create an Estimator for TensorFlow 2.x from a Keras model (using a Model Creator Function). For example:
def model_creator(config):
    model = create_keras_lenet_model()
    model.compile(optimizer=keras.optimizers.RMSprop(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

est = Estimator.from_keras(model_creator=model_creator)
The model_creator argument should be a function that takes a config dictionary and returns a compiled Keras model.
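The Model Creator Function pattern lets each worker build and compile its own local model from the same config dictionary. A minimal sketch (a tiny tf.keras model standing in for create_keras_lenet_model, and a hypothetical "lr" config key):

```python
import tensorflow as tf

def model_creator(config):
    # stand-in for create_keras_lenet_model(); any compiled Keras model works
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    # hyper-parameters come from the config dict that the Estimator passes in
    model.compile(optimizer=tf.keras.optimizers.RMSprop(config.get("lr", 1e-3)),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

# Orca would invoke model_creator(config) once on each worker, e.g.:
# est = Estimator.from_keras(model_creator=model_creator, config={"lr": 1e-3})
model = model_creator({"lr": 1e-3})
```

Building the model inside the creator (rather than passing a built model) is what allows each remote worker to construct its replica locally.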
Then users can perform distributed model training and inference as follows:
def train_data_creator(config, batch_size):
    dataset = tfds.load(name="mnist", split="train")
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch_size)
    return dataset

stats = est.fit(data=train_data_creator,
                epochs=max_epoch,
                steps_per_epoch=total_size // batch_size)

predictions = est.predict(data=df,
                          feature_cols=['image'])
The data argument in the fit method can be a Spark DataFrame, an XShards or a Data Creator Function (that returns a tf.data.Dataset). The data argument in the predict method can be a Spark DataFrame or an XShards. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
For more details, view the distributed TensorFlow training/inference page.
3. PyTorch Estimator¶
Using BigDL backend
Users may create a PyTorch Estimator using the BigDL backend (currently the default for PyTorch) as follows:
model = LeNet() # a torch.nn.Module
model.train()
criterion = nn.NLLLoss()
adam = torch.optim.Adam(model.parameters(), args.lr)
est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion)
Then users can perform distributed model training and inference as follows:
est.fit(data=train_loader, epochs=args.epochs)
predictions = est.predict(xshards)
The input to the fit method can be a torch.utils.data.DataLoader, a Spark DataFrame, an XShards, or a Data Creator Function (that returns a torch.utils.data.DataLoader). The input to the predict method should be a Spark DataFrame or an XShards. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
Using torch.distributed or Horovod backend
Alternatively, users can create a PyTorch Estimator using the torch.distributed or Horovod backend by setting the backend argument to “torch_distributed” or “horovod”. In this case, the model and optimizer should be wrapped in Creator Functions. For example:
def model_creator(config):
    model = LeNet()  # a torch.nn.Module
    model.train()
    return model

def optimizer_creator(model, config):
    return torch.optim.Adam(model.parameters(), config["lr"])

est = Estimator.from_torch(model=model_creator,
                           optimizer=optimizer_creator,
                           loss=nn.NLLLoss(),
                           config={"lr": 1e-2},
                           backend="torch_distributed")  # or backend="horovod"
Then users can perform distributed model training and inference as follows:
est.fit(data=train_loader_func, epochs=args.epochs)
predictions = est.predict(data=df,
                          feature_cols=['image'])
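The train_loader_func passed to fit above is a Data Creator Function. A minimal sketch, with random tensors standing in for a real dataset (the (config, batch_size) signature mirrors the creator pattern shown for TensorFlow above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train_loader_func(config, batch_size):
    # random stand-in data; a real creator would load MNIST or similar here
    images = torch.randn(64, 1, 28, 28)
    labels = torch.randint(0, 10, (64,))
    # each worker calls this function to build its own local DataLoader
    return DataLoader(TensorDataset(images, labels),
                      batch_size=batch_size, shuffle=True)

loader = train_loader_func({}, batch_size=16)
```

As with the model and optimizer creators, constructing the DataLoader inside a function lets every worker build its data pipeline locally instead of serializing one loader across the cluster.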
The input to the fit method can be a Spark DataFrame, an XShards, or a Data Creator Function (that returns a torch.utils.data.DataLoader). The data argument in the predict method can be a Spark DataFrame or an XShards. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
For more details, view the distributed PyTorch training/inference page.
4. MXNet Estimator¶
The user may create an MXNet Estimator as follows:
from bigdl.orca.learn.mxnet import Estimator, create_config

def get_model(config):
    net = LeNet()  # a mxnet.gluon.Block
    return net

def get_loss(config):
    return gluon.loss.SoftmaxCrossEntropyLoss()

config = create_config(log_interval=2, optimizer="adam",
                       optimizer_params={'learning_rate': 0.02})
est = Estimator.from_mxnet(config=config,
                           model_creator=get_model,
                           loss_creator=get_loss,
                           num_workers=2)
Then the user can perform distributed model training as follows:
import numpy as np

def get_train_data_iter(config, kv):
    train = mx.io.NDArrayIter(data_ndarray, label_ndarray,
                              batch_size=config["batch_size"], shuffle=True)
    return train

est.fit(get_train_data_iter, epochs=2)
The input to the fit method can be an XShards, or a Data Creator Function (that returns an MXNet DataIter/DataLoader). See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
5. BigDL Estimator¶
The user may create a BigDL Estimator as follows:
from bigdl.dllib.nn.criterion import *
from bigdl.dllib.nn.layer import *
from bigdl.dllib.optim.optimizer import *
from bigdl.orca.learn.bigdl import Estimator
linear_model = Sequential().add(Linear(2, 2))
mse_criterion = MSECriterion()
est = Estimator.from_bigdl(model=linear_model, loss=mse_criterion, optimizer=Adam())
Then the user can perform distributed model training and inference as follows:
# read a Spark DataFrame
df = spark.read.parquet("data.parquet")

# distributed model training (1 epoch)
est.fit(df, 1, batch_size=4)

# distributed model inference
result_df = est.predict(df)
The input to the fit and predict methods can be a Spark DataFrame or an XShards. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.
6. OpenVINO Estimator¶
The user may create an OpenVINO Estimator as follows:
from bigdl.orca.learn.openvino import Estimator
model_path = "The/file_path/to/the/OpenVINO_IR_xml_file"
est = Estimator.from_openvino(model_path=model_path)
Then the user can perform distributed model inference as follows:
# ndarray
input_data = np.random.random([20, 4, 3, 224, 224])
result = est.predict(input_data)

# XShards
shards = XShards.partition({"x": input_data})
result_shards = est.predict(shards)
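Conceptually, XShards.partition splits each NumPy array in the dict along its first axis into one chunk per partition, so each node runs inference on its own chunk. A plain-NumPy, single-machine analogue (an illustration only, not the actual XShards implementation):

```python
import numpy as np

def partition_arrays(data, num_partitions):
    # split each array in the dict along its first axis, one chunk per partition
    chunks = {k: np.array_split(v, num_partitions) for k, v in data.items()}
    return [{k: chunks[k][i] for k in data} for i in range(num_partitions)]

input_data = np.random.random([20, 4, 3, 224, 224])
parts = partition_arrays({"x": input_data}, 4)  # 4 partitions of 5 samples each
```

Concatenating the per-partition results back along the first axis recovers output for the full batch, which is how the distributed predict results line up with the original input order.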
The input to the predict method can be an XShards or a NumPy array. See the data-parallel processing pipeline page for more details.
View the related Python API doc for more details.