diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
index a6716ba40..52644c360 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
@@ -17,11 +17,13 @@
 package org.deeplearning4j.iterator;
 
+import lombok.Getter;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
+import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
+import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
 import org.junit.Test;
-import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -42,8 +44,12 @@ import static org.junit.Assert.*;
 
 public class TestBertIterator extends BaseDL4JTest {
 
-    private File pathToVocab = Resources.asFile("other/vocab.txt");
+    private static File pathToVocab = Resources.asFile("other/vocab.txt");
     private static Charset c = StandardCharsets.UTF_8;
+    private static String shortSentence = "I saw a girl with a telescope.";
+    private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+    private static String sentenceA = "Goodnight noises everywhere";
+    private static String sentenceB = "Goodnight moon";
 
     public TestBertIterator() throws IOException {
     }
@@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest {
 
     @Test(timeout = 20000L)
     public void testBertSequenceClassification() throws Exception {
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
 
@@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getFeatures(0));
         System.out.println(mds.getFeaturesMaskArray(0));
 
-
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
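+        //Build the expected features and masks one sentence (row) at a time and vstack them into minibatch arrays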
+        for (int i = 0; i < minibatchSize; i++) {
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF, expFTemp);
+                expM = Nd4j.vstack(expM, expMTemp);
             }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
         }
-
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
-
         assertEquals(expF, mds.getFeatures(0));
         assertEquals(expM, mds.getFeaturesMaskArray(0));
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
 
-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
        assertTrue(b.hasNext());
 
-        forInference.set(0, toTokenize2);
         //Same thing, but with segment ID also
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
         mds = b.next();
         assertEquals(2, mds.getFeatures().length);
-        //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
-        assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
 
         //Segment ID should be all 0s for single segment task
         INDArray segmentId = expM.like();
         assertEquals(segmentId, mds.getFeatures(1));
-        assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
+        assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]);
     }
 
     @Test(timeout = 20000L)
     public void testBertUnsupervised() throws Exception {
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
         //Task 1: Unsupervised
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.UNSUPERVISED)
                 .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
                 .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
                 .maskToken("[MASK]")
                 .build();
 
-        System.out.println("Mask token index: " + t.getVocab().get("[MASK]"));
+        System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]"));
 
         MultiDataSet mds = b.next();
         System.out.println(mds.getFeatures(0));
@@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getLabels(0));
         System.out.println(mds.getLabelsMaskArray(0));
 
-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
         assertTrue(b.hasNext());
@@ -164,40 +159,34 @@ public class TestBertIterator extends BaseDL4JTest {
 
     @Test(timeout = 20000L)
     public void testLengthHandling() throws Exception {
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        System.out.println(tokens);
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        System.out.println(tokens2);
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF, expFTemp);
+                expM = Nd4j.vstack(expM, expMTemp);
             }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
         }
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
-
 
         //--------------------------------------------------------------
         //Fixed length: clip or pad - already tested in other tests
@@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest {
 
         //Any length: as long as we need to fit longest sequence
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
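+        //With ANY_LENGTH the arrays are sized to the longest tokenized sentence in the minibatch - 14 tokens here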
         MultiDataSet mds = b.next();
@@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest {
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
         assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
         assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
-        assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
 
         //Clip only: clip to maximum, but don't pad if less
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
-
         mds = b.next();
         expShape = new long[]{2, 14};
         assertArrayEquals(expShape, mds.getFeatures(0).shape());
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
@@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testMinibatchPadding() throws Exception {
         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        String toTokenize3 = "Goodnight noises everywhere";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        forInference.add(toTokenize3);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
-        }
-
-        INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens3 = t.create(toTokenize3).getTokens();
-        for (int i = 0; i < tokens3.size(); i++) {
-            String token = tokens3.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx3.putScalar(0, i, idx);
-            expM3.putScalar(0, i, 1);
-        }
-
+        int minibatchSize = 3;
+        TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize);
         INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
-        INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
-        INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF.dup(), expFTemp);
+                expM = Nd4j.vstack(expM.dup(), expMTemp);
+            }
+        }
+
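+        //The 3-sentence minibatch is padded up to size 4, so the fourth feature/mask row stays all zeros and its label mask entry is 0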
+        expF = Nd4j.vstack(expF, zeros);
+        expM = Nd4j.vstack(expM, zeros);
+        INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}});
         INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
         expLM.putScalar(0, 0, 1);
         expLM.putScalar(1, 0, 1);
@@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest {
 
         //--------------------------------------------------------------
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(4)
+                .minibatchSize(minibatchSize + 1)
                 .padMinibatches(true)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
 
@@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest {
         assertEquals(expL, mds.getLabels(0));
         assertEquals(expLM, mds.getLabelsMaskArray(0));
 
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
     }
 
+    /*
+    Checks that a MultiDataSet built from a sentence pair is equal to the hstack'd MultiDataSets built
+    from the left and right sides of the pair.
+    Checks different max lengths to exercise popping (truncation) and padding.
+    */
     @Test
     public void testSentencePairsSingle() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
         boolean prependAppend;
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
+        int numOfSentences;
+
+        TestSentenceHelper testHelper = new TestSentenceHelper();
+        int shortL = testHelper.getShortestL();
+        int longL = testHelper.getLongestL();
 
         Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
-        MultiDataSet shortLongPair, shortSentence, longSentence;
+        MultiDataSet fromPair, leftSide, rightSide;
 
         // check for pair max length exactly equal to sum of lengths - pop neither, no padding
         // should be the same as hstack with segment ids 1 for second sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).addi(1);
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        numOfSentences = 1;
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
 
         //check for pair max length greater than sum of lengths - pop neither, with padding
         // features should be the same as hstack of shorter and longer padded with prepend/append
         // segment id should be 1 only in the longer sentence, for part of the length of the sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        numOfSentences = 1;
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
 
         //check for pair max length less than shorter sentence - pop both
         //should be the same as hstack with segment ids 1 for second sentence if no prepend/append
-        int maxL = shortL - 2;
+        int maxL = 5; //check an odd max length
+        numOfSentences = 3;
         prependAppend = false;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).addi(1);
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).addi(1);
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
     }
 
+    /*
+    Same idea as the previous test - construct MultiDataSets from a BertIterator over separate sentences
+    and check them against one built from sentence pairs.
+    Checks various max lengths; uses sentences of varying lengths.
+    */
     @Test
     public void testSentencePairsUnequalLengths() throws IOException {
-        //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
-        //should be identical to hstack if there is no append, prepend
-        //batch size is 2
-        int mbS = 4;
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
-        String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
-        int sent1L = t.create(sent1).countTokens();
-        int sent2L = t.create(sent2).countTokens();
-        //won't check 2*shortL + 1 because this will always pop on the left
-        for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
+
+        int minibatchSize = 4;
+        int numOfSentencesinIter = 3;
+
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter);
+        int shortL = testPairHelper.getShortL();
+        int longL = testPairHelper.getLongL();
+        int sent1L = testPairHelper.getSentenceALen();
+        int sent2L = testPairHelper.getSentenceBLen();
+
+        System.out.println("Sentence Pairs, Left");
+        System.out.println(testPairHelper.getSentencesLeft());
+        System.out.println("Sentence Pairs, Right");
+        System.out.println(testPairHelper.getSentencesRight());
+
+        //anything outside this range will additionally need to check padding/truncation
+        for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) {
+
+            System.out.println("Running for max length = " + maxL);
 
             MultiDataSet leftMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider())
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
+                    .sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider())
                    .padMinibatches(true)
                     .build().next();
 
             MultiDataSet rightMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                    .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider(true))
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
+                    .sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider())
                     .padMinibatches(true)
                     .build().next();
 
             MultiDataSet pairMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
-                    .sentencePairProvider(new TestSentencePairProvider())
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL)
+                    .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                     .padMinibatches(true)
                     .build().next();
 
-            //Left sentences here are {{shortSent},
-            //                         {longSent},
-            //                         {Sent1}}
-            //Right sentences here are {{longSent},
-            //                          {shortSent},
-            //                          {Sent2}}
-            //The sentence pairs here are {{shortSent,longSent},
-            //                             {longSent,shortSent}
-            //                             {Sent1, Sent2}}
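+            //Expected rows: row 0 keeps the full short left sentence and truncates the long right one,
+            //row 1 truncates the long left sentence and keeps the short right one,
+            //row 2 (sentenceA + sentenceB) fits entirely and is zero-padded up to maxL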
 
             //CHECK FEATURES
-            INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
+            INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL);
             //left side
             INDArray leftFeatures = leftMDS.getFeatures(0);
             INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
             INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
-            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
+            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L));
             //right side
             INDArray rightFeatures = rightMDS.getFeatures(0);
             INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
             INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
-            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
+            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L));
             //expected pair
-            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
-            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
-            combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
+            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat));
+            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat));
+            combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat));
 
             assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
             assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
             assertEquals(combinedFeat, pairMDS.getFeatures(0));
 
             //CHECK SEGMENT ID
-            INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
+            INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL);
             combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
             combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
-            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
+            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1);
             assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
             assertEquals(maxL, combinedFetSeg.shape()[1]);
             assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
+
+            testPairHelper.getPairSentenceProvider().reset();
         }
     }
 
     @Test
     public void testSentencePairFeaturizer() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<Pair<String, String>> listSentencePair = new ArrayList<>();
-        listSentencePair.add(new Pair<>(shortSent, longSent));
-        listSentencePair.add(new Pair<>(longSent, shortSent));
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        int minibatchSize = 2;
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
-                .minibatchSize(2)
+                .tokenizer(testPairHelper.getTokenizer())
+                .minibatchSize(minibatchSize)
                 .padMinibatches(true)
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testPairHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
-                .sentencePairProvider(new TestSentencePairProvider())
+                .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                 .prependToken("[CLS]")
.appendToken("[SEP]") .build(); @@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest { INDArray[] featuresArr = mds.getFeatures(); INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays(); - Pair p = b.featurizeSentencePairs(listSentencePair); + Pair p = b.featurizeSentencePairs(testPairHelper.getSentencePairs()); assertEquals(p.getFirst().length, 2); assertEquals(featuresArr[0], p.getFirst()[0]); assertEquals(featuresArr[1], p.getFirst()[1]); - //assertEquals(p.getSecond().length, 2); assertEquals(featuresMaskArr[0], p.getSecond()[0]); - //assertEquals(featuresMaskArr[1], p.getSecond()[1]); } /** - * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append + * Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator + * with given max lengths and whether to prepend/append * Idea is the sentence pair dataset can be constructed from the single sentence datasets - * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" - * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope." - * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" */ - private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend) throws IOException { + private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend, int numSentences) throws IOException { BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); int maxforPair = maxLengths.getFirst(); int maxPartOne = maxLengths.getSecond(); @@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest { BertIterator.Builder commonBuilder; commonBuilder = BertIterator.builder() .tokenizer(t) - .minibatchSize(1) + .minibatchSize(4) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) .vocabMap(t.getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION); - BertIterator shortLongPairFirstIter = commonBuilder + BertIterator pairIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) - .sentencePairProvider(new TestSentencePairProvider()) + .sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator shortFirstIter = commonBuilder + BertIterator leftIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) - .sentenceProvider(new TestSentenceProvider()) + .sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator longFirstIter = commonBuilder + BertIterator rightIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) - .sentenceProvider(new TestSentenceProvider(true)) + .sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider()) .prependToken(null) .appendToken(prependAppend ? 
"[SEP]" : null) .build(); - return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next()); + return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next()); } - private static class TestSentenceProvider implements LabeledSentenceProvider { + @Getter + private static class TestSentencePairsHelper { - private int pos = 0; - private boolean invert; + private List sentencesLeft; + private List sentencesRight; + private List> sentencePairs; + private List> tokenizedSentencesLeft; + private List> tokenizedSentencesRight; + private List labels; + private int shortL; + private int longL; + private int sentenceALen; + private int sentenceBLen; + private BertWordPieceTokenizerFactory tokenizer; + private CollectionLabeledPairSentenceProvider pairSentenceProvider; - private TestSentenceProvider() { - this.invert = false; + private TestSentencePairsHelper() throws IOException { + this(3); } - private TestSentenceProvider(boolean invert) { - this.invert = invert; - } - - @Override - public boolean hasNext() { - return pos < totalNumSentences(); - } - - @Override - public Pair nextSentence() { - Preconditions.checkState(hasNext()); - if (pos == 0) { - pos++; - if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive"); - return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - } else { - if (pos == 1) { - pos++; - if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - return new Pair<>("I saw a girl with a telescope.", "positive"); + private TestSentencePairsHelper(int minibatchSize) throws IOException { + sentencesLeft = new ArrayList<>(); + sentencesRight = new ArrayList<>(); + sentencePairs = new ArrayList<>(); + labels = new ArrayList<>(); + tokenizedSentencesLeft = new ArrayList<>(); + tokenizedSentencesRight = new ArrayList<>(); + tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + sentencesLeft.add(shortSentence); + sentencesRight.add(longSentence); + sentencePairs.add(new Pair<>(shortSentence, longSentence)); + labels.add("positive"); + if (minibatchSize > 1) { + sentencesLeft.add(longSentence); + sentencesRight.add(shortSentence); + sentencePairs.add(new Pair<>(longSentence, shortSentence)); + labels.add("negative"); + if (minibatchSize > 2) { + sentencesLeft.add(sentenceA); + sentencesRight.add(sentenceB); + sentencePairs.add(new Pair<>(sentenceA, sentenceB)); + labels.add("positive"); } - pos++; - if (!invert) - return new Pair<>("Goodnight noises everywhere", "positive"); - return new Pair<>("Goodnight moon", "positive"); } - } - - @Override - public void reset() { - pos = 0; - } - - @Override - public int totalNumSentences() { - return 3; - } - - @Override - public List allLabels() { - return Arrays.asList("positive", "negative"); - } - - @Override - public int numLabelClasses() { - return 2; + for (int i = 0; i < minibatchSize; i++) { + List tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens(); + List tokensR = tokenizer.create(sentencesRight.get(i)).getTokens(); + if (i == 0) { + shortL = tokensL.size(); + longL = tokensR.size(); + } + if (i == 2) { + sentenceALen = tokensL.size(); + sentenceBLen = tokensR.size(); + } + tokenizedSentencesLeft.add(tokensL); + tokenizedSentencesRight.add(tokensR); + } + pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null); } } - private static class TestSentencePairProvider implements LabeledPairSentenceProvider { + @Getter + 
+            for (int i = 0; i < minibatchSize; i++) {
+                List<String> tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens();
+                List<String> tokensR = tokenizer.create(sentencesRight.get(i)).getTokens();
+                if (i == 0) {
+                    shortL = tokensL.size();
+                    longL = tokensR.size();
+                }
+                if (i == 2) {
+                    sentenceALen = tokensL.size();
+                    sentenceBLen = tokensR.size();
+                }
+                tokenizedSentencesLeft.add(tokensL);
+                tokenizedSentencesRight.add(tokensR);
+            }
+            pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null);
         }
     }
 
-    private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
+    @Getter
+    private static class TestSentenceHelper {
 
-        private int pos = 0;
+        private List<String> sentences;
+        private List<List<String>> tokenizedSentences;
+        private List<String> labels;
+        private int shortestL = 0;
+        private int longestL = 0;
+        private BertWordPieceTokenizerFactory tokenizer;
+        private CollectionLabeledSentenceProvider sentenceProvider;
 
-        @Override
-        public boolean hasNext() {
-            return pos < totalNumSentences();
+        private TestSentenceHelper() throws IOException {
+            this(false, 2);
        }
 
-        @Override
-        public Triple<String, String, String> nextSentencePair() {
-            Preconditions.checkState(hasNext());
-            if (pos == 0) {
-                pos++;
-                return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
-            } else {
-                if (pos == 1) {
-                    pos++;
-                    return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
+        private TestSentenceHelper(int minibatchSize) throws IOException {
+            this(false, minibatchSize);
+        }
+
+        private TestSentenceHelper(boolean alternateOrder) throws IOException {
+            this(alternateOrder, 3);
+        }
+
+        private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException {
+            sentences = new ArrayList<>();
+            labels = new ArrayList<>();
+            tokenizedSentences = new ArrayList<>();
+            tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+            if (!alternateOrder) {
+                sentences.add(shortSentence);
+                labels.add("positive");
+                if (minibatchSize > 1) {
+                    sentences.add(longSentence);
+                    labels.add("negative");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceA);
+                        labels.add("positive");
+                    }
+                }
+            } else {
+                sentences.add(longSentence);
+                labels.add("negative");
+                if (minibatchSize > 1) {
+                    sentences.add(shortSentence);
+                    labels.add("positive");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceB);
+                        labels.add("positive");
+                    }
                 }
-                pos++;
-                return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
             }
-        }
-
-        @Override
-        public void reset() {
-            pos = 0;
-        }
-
-        @Override
-        public int totalNumSentences() {
-            return 3;
-        }
-
-        @Override
-        public List<String> allLabels() {
-            return Arrays.asList("positive", "negative");
-        }
-
-        @Override
-        public int numLabelClasses() {
-            return 2;
+            for (int i = 0; i < sentences.size(); i++) {
+                List<String> tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens();
+                if (i == 0)
+                    shortestL = tokenizedSentence.size();
+                if (tokenizedSentence.size() > longestL)
+                    longestL = tokenizedSentence.size();
+                if (tokenizedSentence.size() < shortestL)
+                    shortestL = tokenizedSentence.size();
+                tokenizedSentences.add(tokenizedSentence);
+            }
+            sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null);
         }
     }