I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels have many more entries than others, compared to the whole dataset. I want to select a limited number of entries from each of these labels in order to get a balanced dataset. Below is my attempt, but it takes more than 24 hours to filter 1k entries from each label (33 different labels).
```python
import tensorflow as tf

# Note: this call has no effect, its return value is discarded
tf.compat.as_str(
    bytes_or_text='str', encoding='utf-8'
)

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception:  # no TPU available: fall back to the default strategy
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
features, feature_lists = detect_schema(dataset)  # detect_schema is defined elsewhere (not shown)

# Decoding TFRecord serialized data
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']

dataset = dataset.map(lambda x: tf.py_function(func=decode_data, inp=[x], Tout=(tf.string, tf.string)))

# Filtering and concatenating the samples
def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Filtering the chosen label (a plain local variable instead of locals()[label])
        label_dataset = dataset.filter(
            lambda x, y: tf.greater(
                tf.reduce_sum(tf.cast(tf.equal(tf.constant(label, dtype=tf.int64), y), tf.float32)),
                tf.constant(0.)))
        # Appending a limited sample
        datasets_list.append(label_dataset.take(sample_size))
    # Concatenating the datasets
    concat_dataset = datasets_list[0]
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset

balanced_data = balanced_dataset(dataset, labels_list=list(decod_dic.values()), sample_size=1000)
```
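Why the attempt above is slow: each `filter(...).take(sample_size)` in the concatenated dataset reads the TFRecord file from the beginning and only stops once it has found `sample_size` matches, so a rare label (or one with fewer than 1k entries) forces a scan of most or all of the file, once per label. The plain-Python sketch below only counts those reads; `records_read_per_label_filter` is a hypothetical name, and the list of labels (one label per record, a simplification) stands in for the TFRecord stream.

```python
def records_read_per_label_filter(labels, label_set, sample_size):
    """Count the records read by a filter/take/concatenate pipeline.

    `labels` stands in for the TFRecord stream (one label per record).
    Each per-label filter restarts from the first record and stops once
    it has found `sample_size` matches or reaches the end of the input.
    """
    total_reads = 0
    for target in label_set:
        found = 0
        for label in labels:  # every filtered dataset rescans from the start
            total_reads += 1
            if label == target:
                found += 1
                if found == sample_size:
                    break  # take(sample_size) is satisfied
    return total_reads


# Worst case: every label has fewer than sample_size entries,
# so each of the 3 filters reads all 30 records.
labels = [0, 1, 2] * 10
print(records_read_per_label_filter(labels, [0, 1, 2], sample_size=15))  # 90
print(len(labels))  # a single pass reads 30
```

With 33 labels and about 4M records, the same reasoning puts the worst case at roughly 33 full scans of the file.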
One way to solve this is by using the `group_by_window` method, where the `window_size` would be the sample size of each class (in your case 1k).
```python
ds = ds.group_by_window(
    # Use the label as the grouping key
    key_func=lambda _, l: l,
    # Batch each window into a single batch of (up to) sample_size elements
    reduce_func=lambda _, window: window.batch(sample_size),
    # Use sample_size as the window size
    window_size=sample_size)
```
This will form single-class batches of size `sample_size`. But there is one problem: there will be multiple batches of the same class, and you only need one batch per class.
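To see where the extra batches come from, here is a plain-Python model of `group_by_window` followed by `window.batch(sample_size)`, as I understand tf.data's behaviour: elements are buffered per key, a window is emitted as soon as it holds `window_size` elements, and any partial windows are flushed when the input ends. The function below is a hypothetical illustration, not the tf.data implementation.

```python
from collections import Counter, defaultdict


def group_by_window(elements, key_func, window_size):
    """Plain-Python model of grouping by key into fixed-size windows."""
    buffers = defaultdict(list)
    for elem in elements:
        key = key_func(elem)
        buffers[key].append(elem)
        if len(buffers[key]) == window_size:
            yield key, buffers.pop(key)  # a full window is emitted immediately
    for key, window in buffers.items():  # flush partial windows at end of input
        yield key, window


labels = [0] * 100 + [1] * 200 + [2] * 15
records = list(enumerate(labels))  # (feature, label) pairs
windows = list(group_by_window(records, key_func=lambda r: r[1], window_size=15))
per_class = Counter(key for key, _ in windows)
print(per_class)  # 7 windows for class 0, 14 for class 1, 1 for class 2
```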
To solve this, we attach a counter to each batch and then keep only the batches with `count == 0`, which gives the first batch of every class.
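The counter idea without tf.data: keep a per-class counter, read it for every window of that class, increment it, and keep only the windows whose counter was 0. This is a plain-Python sketch with a hypothetical function name; in the TensorFlow version the counter lives in a `MutableHashTable`, because a Python dict cannot be updated from inside the traced tf.data pipeline.

```python
from collections import defaultdict


def first_window_per_key(windows):
    """Tag each (key, window) with a per-key counter and keep count == 0."""
    counts = defaultdict(int)
    for key, window in windows:
        count = counts[key]      # read the current counter for this class
        counts[key] = count + 1  # then increment it for the next window
        if count == 0:
            yield key, window


windows = [(0, [0, 1]), (1, [2, 3]), (0, [4, 5]), (2, [6]), (1, [7, 8])]
kept = list(first_window_per_key(windows))
print(kept)  # [(0, [0, 1]), (1, [2, 3]), (2, [6])]
```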
Let's define an example:
```python
import numpy as np
import tensorflow as tf

# 3 labels, chosen for simplicity, with 100, 200 and 15 samples respectively
# (int64 to match the key dtype of the lookup table used below)
labels = np.array(sum([[label]*repeat for label, repeat in zip([0, 1, 2], [100, 200, 15])], []), dtype=np.int64)
features = np.arange(len(labels))
np.unique(labels, return_counts=True)
# (array([0, 1, 2]), array([100, 200, 15]))

sample_size = 15  # pick a sample of 15 from each class
```
We create a dataset from the above inputs:

```python
ds = tf.data.Dataset.from_tensor_slices((features, labels))
```
In the window function above, we modify `reduce_func` to add the counter, so each batch has 3 elements `(X_batch, y_batch, label_counter)`:
```python
def reduce_func(key, window):
    # Read this class's counter, then store counter + 1 for its next window
    z = table.lookup(key)
    table.insert(key, z + 1)
    # Batch the window and tag every batch with the counter value
    return window.batch(sample_size).map(lambda a, b: (a, b, z))

# Group by window
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.group_by_window(
    # Use the label as the grouping key
    key_func=lambda _, l: l,
    # Batch each window and attach its counter
    reduce_func=reduce_func,
    # Use sample_size as the window size
    window_size=sample_size)
```
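The counter logic in `reduce_func` is implemented with a lookup table: each class's counter is read from it and then updated in it. It is initialized as shown below (run this before building the `group_by_window` dataset, because tf.data traces `reduce_func`, and therefore needs `table`, when the dataset is created):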
```python
n_classes = 3
keys = tf.range(0, n_classes, dtype=tf.int64)
vals = tf.zeros_like(keys, dtype=tf.int64)
table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.int64,
                                                value_dtype=tf.int64,
                                                default_value=-1)
table.insert(keys, vals)
```
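A `MutableHashTable` with `default_value=-1` behaves like the dictionary-backed stand-in below (`CounterTable` and `window_counts` are hypothetical names for this illustration). It also shows a caveat worth checking: the table keys must cover every label. A label that was never inserted gets -1 for its first window and 0 for its second, so the `count == 0` filter would keep its second batch instead of its first. For 33 classes this assumes `n_classes = 33` and labels encoded as the integers 0 to 32.

```python
class CounterTable:
    """Dict-based stand-in for MutableHashTable(default_value=-1)."""

    def __init__(self, default_value=-1):
        self._data = {}
        self._default = default_value

    def insert(self, key, value):
        self._data[key] = value

    def lookup(self, key):
        return self._data.get(key, self._default)


def window_counts(window_keys, table):
    """Mimic reduce_func: read the counter, then store counter + 1."""
    tags = []
    for key in window_keys:
        z = table.lookup(key)
        table.insert(key, z + 1)
        tags.append(z)
    return tags


n_classes = 3
table = CounterTable()
for k in range(n_classes):  # same initialization as tf.range + zeros_like
    table.insert(k, 0)

# Label 3 was never inserted, so its *second* window gets count 0
print(window_counts([0, 1, 0, 2, 3, 3], table))  # [0, 0, 1, 0, -1, 0]
```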
Now we keep only the batches where `count == 0` and remove the count element to form `(X, y)` batch pairs:
```python
ds = ds.filter(lambda x, y, count: count == 0)
ds = ds.map(lambda x, y, count: (x, y))
```
Output:

```python
for x, y in ds:
    print(x.numpy(), y.numpy())
```

```
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[100 101 102 103 104 105 106 107 108 109 110 111 112 113 114] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[300 301 302 303 304 305 306 307 308 309 310 311 312 313 314] [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
```
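An assumption worth testing on the real data: since the filtered dataset yields one batch per class, adding `ds = ds.take(n_classes)` after the filter should let tf.data stop reading the file once every class has produced its batch, instead of scanning all 4M records. (A class with fewer than `sample_size` entries only produces its partial batch at the end of the input, so in that case no early stop happens.) The plain-Python sketch below mirrors the pipeline (group by label, keep the first window, optionally stop early) so the logic and the number of records read can be checked without TensorFlow; `balanced_sample` and `stop_early` are hypothetical names.

```python
from collections import defaultdict


def balanced_sample(records, n_classes, sample_size, stop_early=True):
    """Single pass: keep the first `sample_size` features of each label.

    records: iterable of (feature, label) pairs, labels in range(n_classes).
    Returns (batches, records_read).
    """
    buffers = defaultdict(list)
    done = {}
    read = 0
    for feature, label in records:
        read += 1
        if label in done:
            continue  # this class already has its batch (count > 0)
        buffers[label].append(feature)
        if len(buffers[label]) == sample_size:
            done[label] = buffers.pop(label)
            if stop_early and len(done) == n_classes:
                break  # every class is full: like take(n_classes), stop reading
    # classes with fewer than sample_size entries keep their partial batch
    for label, partial in buffers.items():
        done.setdefault(label, partial)
    return done, read


labels = [0] * 100 + [1] * 200 + [2] * 15
batches, read = balanced_sample(enumerate(labels), n_classes=3, sample_size=15)
print(batches[1][:3], read)  # [100, 101, 102] 315

# Rare class first: every class is full after 130 records, the rest is skipped
batches, read = balanced_sample(enumerate([2] * 15 + [0] * 100 + [1] * 200), 3, 15)
print(read)  # 130
```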