diff --git a/trax/data/inputs.py b/trax/data/inputs.py
index a2f2035ee..bcccde901 100644
--- a/trax/data/inputs.py
+++ b/trax/data/inputs.py
@@ -681,7 +681,7 @@ def batch(generator, batch_size):
       # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)]
       # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3])
       try:
-        batched_example = tuple(np.stack(x) for x in zip(*buf))
+        batched_example = tuple(pad_to_max_dims([np.asarray(tensor) for tensor in x]) for x in zip(*buf))
       except ValueError as e:
         for j in range(len(buf)):
           logging.error('Batch[%d][%d] input shape: %r output shape: %r',
diff --git a/trax/data/inputs_test.py b/trax/data/inputs_test.py
index 7181589e1..05e888b97 100644
--- a/trax/data/inputs_test.py
+++ b/trax/data/inputs_test.py
@@ -83,6 +83,13 @@ def test_batch_data(self):
     self.assertLen(batch, 2)
     self.assertEqual(batch[0].shape, (10,))
 
+  def test_batch_data_padding(self):
+    dataset = (([1] * (10 - i), i+1) for i in range(10))
+    batches = data.batch(dataset, 10)
+    batch = next(batches)
+    self.assertEqual(batch[0].shape, (10, 10))
+    self.assertTrue(np.array_equal(batch[0][-1], np.asarray([1] + 9 * [0])))
+
   def test_batch_exception_size(self):
     dataset = ((i, i + 1) for i in range(10))
     with self.assertRaises(ValueError):