What is sharding in the context of machine learning specifically (a more generic, framework-agnostic question is asked here), and how is it implemented in TensorFlow?
When speaking about the data pipeline in machine learning, what is referred to as sharding, and why do we need sharding at all?
In TensorFlow:
In tf.data.Dataset, the shard() method creates a Dataset that includes only 1/num_shards of the original dataset. Sharding is deterministic: the Dataset produced by A.shard(n, i) contains all elements of A whose index mod n == i.
import tensorflow as tf

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
print(list(B.as_numpy_iterator()))  # [0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
print(list(C.as_numpy_iterator()))  # [1, 4, 7]
D = A.shard(num_shards=3, index=2)
print(list(D.as_numpy_iterator()))  # [2, 5, 8]
Important caveats: Be sure to shard before you use any randomizing operator (such as shuffle).
Generally it is best if the shard operator is used early in the dataset pipeline. For example, when reading from a set of TFRecord files, shard before converting the dataset to input samples. This avoids reading every file on every worker. The following is an example of an efficient sharding strategy within a complete pipeline:
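A minimal sketch of such a pipeline, assuming TFRecord input; file_pattern, num_workers, worker_index and parse_example are placeholder names for your own setup, not defined above:

files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
files = files.shard(num_shards=num_workers, index=worker_index)  # shard the file names first
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)  # each worker reads only its own files
dataset = dataset.shuffle(buffer_size=10_000)  # randomize only after sharding
dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)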
Autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (provided the right tf.data.experimental.AutoShardPolicy is set).
This ensures that at each step every worker processes its own, non-overlapping slice of the global batch of dataset elements.
Example of setting the autosharding options:
# 64 identical (feature, label) pairs, grouped into global batches of 16
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
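The policy only takes effect once the dataset is distributed by a strategy. A rough sketch, assuming a MultiWorkerMirroredStrategy has already been configured (e.g. via TF_CONFIG); the loop body is illustrative:

strategy = tf.distribute.MultiWorkerMirroredStrategy()
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for batch in dist_dataset:
    # each worker now sees only its shard of every global batch
    pass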
Note that there is no autosharding in multi-worker training with ParameterServerStrategy; sharding the input is then up to you (one manual approach is sketched below).
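A minimal sketch of one manual approach, assuming the per-worker dataset function receives a tf.distribute.InputContext (how the function is registered with the strategy depends on whether you use Model.fit or a custom training loop); dataset_fn and the range dataset are illustrative only:

def dataset_fn(input_context):
    # Take only this worker's slice; nothing shards the data for you here.
    ds = tf.data.Dataset.range(64)
    ds = ds.shard(num_shards=input_context.num_input_pipelines,
                  index=input_context.input_pipeline_id)
    per_replica_batch = input_context.get_per_replica_batch_size(global_batch_size=16)
    return ds.batch(per_replica_batch).repeat()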
The autosharding options are:
AUTO: This is the default option, which means an attempt will be made to shard by FILE. The attempt to shard by FILE fails if a file-based dataset is not detected, in which case tf.data falls back to sharding by DATA.
FILE: This is the option if you want to shard the input files over all the workers. You should use this option if the number of input files is much larger than the number of workers and the data in the files is evenly distributed. For example, let us distribute 2 files over 2 workers with 1 replica each. File 1 contains [0, 1, 2, 3, 4, 5] and File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and global batch size be 4.
Worker 0:
  Batch 1 = Replica 1: [0, 1]
  Batch 2 = Replica 1: [2, 3]
  Batch 3 = Replica 1: [4]
  Batch 4 = Replica 1: [5]
Worker 1:
  Batch 1 = Replica 2: [6, 7]
  Batch 2 = Replica 2: [8, 9]
  Batch 3 = Replica 2: [10]
  Batch 4 = Replica 2: [11]
DATA: This will autoshard the elements across all the workers. Each worker reads the entire dataset but only processes the shard assigned to it; all other shards are discarded. This is generally used when the number of input files is smaller than the number of workers and you want better sharding of the data across all workers. The downside is that the entire dataset will be read on each worker. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. (A small sketch that mimics this split with a manual shard() follows after this list.)
Worker 0:
  Batch 1 = Replica 1: [0, 1]
  Batch 2 = Replica 1: [4, 5]
  Batch 3 = Replica 1: [8, 9]
Worker 1:
  Batch 1 = Replica 2: [2, 3]
  Batch 2 = Replica 2: [6, 7]
  Batch 3 = Replica 2: [10, 11]
OFF: If you turn off autosharding, each worker will process all the data. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Then each worker will see the following distribution:
Worker 0:
  Batch 1 = Replica 1: [0, 1]
  Batch 2 = Replica 1: [2, 3]
  Batch 3 = Replica 1: [4, 5]
  Batch 4 = Replica 1: [6, 7]
  Batch 5 = Replica 1: [8, 9]
  Batch 6 = Replica 1: [10, 11]
Worker 1:
  Batch 1 = Replica 2: [0, 1]
  Batch 2 = Replica 2: [2, 3]
  Batch 3 = Replica 2: [4, 5]
  Batch 4 = Replica 2: [6, 7]
  Batch 5 = Replica 2: [8, 9]
  Batch 6 = Replica 2: [10, 11]
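To see where the DATA split above comes from, here is a rough sketch that mimics the resulting layout with a manual shard() applied after batching (an illustration of the element layout only, not how tf.distribute implements the policy internally):

# Stand-in for File 1: elements 0..11, batched per replica (global batch 4 / 2 replicas = 2)
dataset = tf.data.Dataset.range(12).batch(2)
for worker_id in range(2):
    worker_shard = dataset.shard(num_shards=2, index=worker_id)
    print(worker_id, list(worker_shard.as_numpy_iterator()))
# 0 [array([0, 1]), array([4, 5]), array([8, 9])]
# 1 [array([2, 3]), array([6, 7]), array([10, 11])]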