分布式训练

相关概念

  • 客户端 (Client): 客户端是一个用于建立 TensorFlow 计算图并创立与集群进行交互的会话层 tensorflow::Session 的程序。一般客户端是通过 python 或 C++ 实现的。一个独立的客户端进程可以同时与多个 TensorFlow 的服务端相连 (上面的计算流程一节),同时一个独立的服务端也可以与多个客户端相连。
  • 集群 (Cluster) : 一个 TensorFlow 的集群里包含了一个或多个作业 (job), 每一个作业又可以拆分成一个或多个任务 (task)。集群的概念主要用与一个特定的高层次对象中,比如说训练神经网络,并行化操作多台机器等等。集群对象可以通过 tf.train.ClusterSpec 来定义。
  • 作业 (Job) : 一个作业可以拆封成多个具有相同目的的任务(task),比如说,一个称之为 ps(parameter server,参数服务器) 的作业中的任务主要是保存和更新变量,而一个名为 work(工作)的作业一般是管理无状态且主要从事计算的任务。一个作业中的任务可以运行于不同的机器上,作业的角色也是灵活可变的,比如说称之为”work”的作业可以保存一些状态。
  • 主节点的服务逻辑 (Master service) : 一个 RPC 服务程序可以用来远程连接一系列的分布式设备,并扮演一个会话终端的角色,主服务程序实现了一个 tensorflow::Session 的借口并负责通过工作节点的服务进程(worker service) 与工作的任务进行通信。所有的主服务程序都有了主节点的服务逻辑。
  • 任务 (Task) : 任务相当于是一个特定的 TesnsorFlow 服务端,其相当于一个独立的进程,该进程属于特定的作业并在作业中拥有对应的序号。
  • TensorFlow 服务端 (TensorFlow server) : 一个运行了 tf.train.Server 实例的进程,其为集群中的一员,并有主节点和工作节点之分。
  • 工作节点的服务逻辑 (Worker service) : 其为一个可以使用本地设备对部分图进行计算的 RPC 逻辑,一个工作节点的服务逻辑实现了 worker_service.proto 接口, 所有的 TensorFlow 服务端均包含工作节点的服务逻辑。

创建 Cluster 集群

cluster = tf.train.ClusterSpec({
    "worker": [
        "worker0.example.com:2222",
        "worker1.example.com:2222",
        "worker2.example.com:2222"
    ],
    "ps": [
        "ps0.example.com:2222",
        "ps1.example.com:2222"
    ]})

# Specifying distributed devices in your model
with tf.device("/job:ps/task:0"):
  weights_1 = tf.Variable(...)
  biases_1 = tf.Variable(...)

with tf.device("/job:ps/task:1"):
  weights_2 = tf.Variable(...)
  biases_2 = tf.Variable(...)

with tf.device("/job:worker/task:7"):
  input, labels = ...
  layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
  logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
  # ...
  train_op = ...

with tf.Session("grpc://worker7.example.com:2222") as sess:
  for _ in range(10000):
    sess.run(train_op)

注意: cluster spec 需要在进程启动时指定,无法实现动态的扩容或缩容,这个问题社区通过引入 Kubernetes 集群管理工具来解决

LocalServer

客户端与服务器一起的版本:

# Start a TensorFlow server as a single-process "cluster".
import tensorflow as tf
c = tf.constant("Hello, distributed TensorFlow!")
server = tf.train.Server.create_local_server()
sess = tf.Session(server.target)  # Create a session on the server.
print sess.run(c)
'Hello, distributed TensorFlow!'

当然,可以将其拆分成 c/s 独立的服务:

# server
import tensorflow as tf
server = tf.train.Server.create_local_server()
server.join()

# client
import tensorflow as tf
c = tf.constant("Hello, distributed TensorFlow!")
sess = tf.Session("grpc://localhost:46685")
print sess.run(c)
'Hello, distributed TensorFlow!'

完整实例

import tensorflow as tf

# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")

# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "","One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS

def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model...
      loss = ...
      global_step = tf.Variable(0)

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir="/tmp/train_logs",
                             init_op=init_op,
                             summary_op=summary_op,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=600)

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      # Loop until the supervisor shuts down or 1000000 steps have completed.
      step = 0
      while not sv.should_stop() and step < 1000000:
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        _, step = sess.run([train_op, global_step])

    # Ask for all the services to stop.
    sv.stop()

if __name__ == "__main__":
  tf.app.run()
# On ps0.example.com:
$ python trainer.py \
     --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
     --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
     --job_name=ps --task_index=0
# On ps1.example.com:
$ python trainer.py \
     --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
     --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
     --job_name=ps --task_index=1
# On worker0.example.com:
$ python trainer.py \
     --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
     --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
     --job_name=worker --task_index=0
# On worker1.example.com:
$ python trainer.py \
     --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
     --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
     --job_name=worker --task_index=1
© Pengfei Ni all right reserved,powered by GitbookUpdated at 2018-03-02 14:52:02

results matching ""

    No results matching ""