takeWhile() working differently with flatmap

Tags: Java, Lambda, Java Stream, Java 9

Java Problem Overview


I am creating snippets with takeWhile to explore its possibilities. When used in conjunction with flatMap, the behaviour is not in line with the expectation. Please find the code snippet below.

String[][] strArray = {{"Sample1", "Sample2"}, {"Sample3", "Sample4", "Sample5"}};

Arrays.stream(strArray)
        .flatMap(indStream -> Arrays.stream(indStream))
        .takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"))
        .forEach(ele -> System.out.println(ele));

Actual Output:

Sample1
Sample2
Sample3
Sample5

ExpectedOutput:

Sample1
Sample2
Sample3

The reason for this expectation is that takeWhile should keep passing elements on only until its predicate first returns false. I have also added print statements inside flatMap for debugging. The streams are returned just twice, which is in line with the expectation.

However, this works just fine without flatMap in the chain.

String[] strArraySingle = {"Sample3", "Sample4", "Sample5"};
Arrays.stream(strArraySingle)
        .takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"))
        .forEach(ele -> System.out.println(ele));

Actual Output:

Sample3

Here the actual output matches with the expected output.

Disclaimer: These snippets are just for coding practice and do not serve any valid use case.

Update: This is bug JDK-8193856; the fix will be available as part of JDK 10. The change corrects the Sink::accept method in WhileOps. Original implementation:

@Override 
public void accept(T t) {
    if (take = predicate.test(t)) {
        downstream.accept(t);
    }
}

Changed Implementation:

@Override
public void accept(T t) {
    if (take && (take = predicate.test(t))) {
        downstream.accept(t);
    }
}
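To see why the extra take && guard matters, here is a minimal stand-alone sketch that replays both versions of accept against a producer that never checks for cancellation, which is how flatMap feeds its sub-streams in JDK 9. The class TakeFlagDemo and its method names are hypothetical and only simulate the flag handling; this is not JDK code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical simulation of the take flag; not the JDK's WhileOps class.
class TakeFlagDemo {

    // Mimics the JDK 9 accept: every test result overwrites the flag,
    // so a later matching element switches taking back on.
    static List<String> buggy(List<String> input, Predicate<String> predicate) {
        List<String> out = new ArrayList<>();
        boolean take = true;
        for (String t : input) { // the producer never asks about cancellation
            if (take = predicate.test(t)) {
                out.add(t);
            }
        }
        return out;
    }

    // Mimics the corrected accept: once take is false it stays false.
    static List<String> fixed(List<String> input, Predicate<String> predicate) {
        List<String> out = new ArrayList<>();
        boolean take = true;
        for (String t : input) {
            if (take && (take = predicate.test(t))) {
                out.add(t);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> input = Arrays.asList("Sample3", "Sample4", "Sample5");
        Predicate<String> notSample4 = ele -> !ele.equalsIgnoreCase("Sample4");
        System.out.println(buggy(input, notSample4)); // [Sample3, Sample5]
        System.out.println(fixed(input, notSample4)); // [Sample3]
    }
}
```

With the guard in place, the result no longer depends on whether the upstream operation honors cancellation.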

Java Solutions


Solution 1 - Java

This is a bug in JDK 9 - from issue #8193856:

> takeWhile is incorrectly assuming that an upstream operation supports and honors cancellation, which unfortunately is not the case for flatMap.

Explanation

If the stream is ordered, takeWhile should show the expected behavior. This is not entirely the case in your code because you use forEach, which waives order. If you care about it, which you do in this example, you should use forEachOrdered instead. Funny thing: That doesn't change anything.

So maybe the stream isn't ordered in the first place? (In that case the behavior is ok.) If you create a temporary variable for the stream created from strArray and check whether it is ordered by executing the expression ((StatefulOp) stream).isOrdered(); at the breakpoint, you will find that it is indeed ordered:

String[][] strArray = {{"Sample1", "Sample2"}, {"Sample3", "Sample4", "Sample5"}};

Stream<String> stream = Arrays.stream(strArray)
		.flatMap(indStream -> Arrays.stream(indStream))
		.takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"));

// breakpoint here
System.out.println(stream);

That means that this is very likely an implementation error.

Into The Code

As others have suspected, I now also think that this might be connected to flatMap being eager. More precisely, both problems might have the same root cause.

Looking into the source of WhileOps, we can see these methods:

@Override
public void accept(T t) {
	if (take = predicate.test(t)) {
		downstream.accept(t);
	}
}

@Override
public boolean cancellationRequested() {
	return !take || downstream.cancellationRequested();
}

This code is used by takeWhile to check for a given stream element t whether the predicate is fulfilled:

  • If so, it passes the element on to the downstream operation, in this case System.out::println.
  • If not, it sets take to false, so when it is asked next time whether the pipeline should be canceled (i.e. it is done), it returns true.

This covers the takeWhile operation. The other thing you need to know is that forEachOrdered leads to the terminal operation executing the method ReferencePipeline::forEachWithCancel:

@Override
final boolean forEachWithCancel(Spliterator<P_OUT> spliterator, Sink<P_OUT> sink) {
	boolean cancelled;
	do { } while (
			!(cancelled = sink.cancellationRequested())
			&& spliterator.tryAdvance(sink));
	return cancelled;
}

All this does is:

  1. check whether pipeline was canceled
  2. if not, advance the sink by one element
  3. stop if this was the last element
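The three steps above can be reproduced with the public Spliterator API. The following is only a sketch: CancelLoopDemo, pullUntil, and the boolean flag are hypothetical stand-ins for the sink and its cancellationRequested method, not the JDK's internal classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Spliterator;

// Hypothetical sketch of a cancel-aware pull loop; not JDK code.
class CancelLoopDemo {

    // Pulls one element at a time and checks the flag before each pull.
    static List<String> pullUntil(String[] source, String stopAt) {
        List<String> taken = new ArrayList<>();
        boolean[] cancelled = {false}; // stands in for sink.cancellationRequested()
        Spliterator<String> spliterator = Arrays.spliterator(source);
        do { } while (
                !cancelled[0]
                && spliterator.tryAdvance(ele -> {
                    if (ele.equalsIgnoreCase(stopAt)) {
                        cancelled[0] = true; // request cancellation
                    } else {
                        taken.add(ele);
                    }
                }));
        return taken;
    }

    public static void main(String[] args) {
        String[] source = {"Sample3", "Sample4", "Sample5"};
        System.out.println(pullUntil(source, "Sample4")); // [Sample3]
    }
}
```

Because the flag is consulted before every tryAdvance call, "Sample5" is never pulled from the spliterator.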

Looks promising, right?

Without flatMap

In the "good case" (without flatMap; your second example) forEachWithCancel directly operates on the WhileOp as sink and you can see how this plays out:

  • ReferencePipeline::forEachWithCancel does its loop:
    • WhileOps::accept is given each stream element
    • WhileOps::cancellationRequested is queried after each element
  • at some point "Sample4" fails the predicate and the stream is canceled

Yay!

With flatMap

In the "bad case" (with flatMap; your first example), forEachWithCancel operates on the flatMap operation, though, which simply calls forEachRemaining on the ArraySpliterator for {"Sample3", "Sample4", "Sample5"}, which does this:

if ((a = array).length >= (hi = fence) &&
	(i = index) >= 0 && i < (index = hi)) {
	do { action.accept((T)a[i]); } while (++i < hi);
}

Ignoring all that hi and fence stuff, which is only used if the array processing is split for a parallel stream, this is a simple loop that passes each element to the takeWhile operation but never checks whether it is cancelled. It will hence eagerly plow through all elements in that "substream" before stopping, likely even through the rest of the stream.
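The effect of forEachRemaining can be observed with the public API as well. In this sketch the consumer raises a cancellation flag at "Sample4", yet every element is still delivered because nothing reads the flag. ForEachRemainingDemo, countVisited, and the flag are hypothetical names used only for illustration.

```java
import java.util.Arrays;
import java.util.Spliterator;

// Hypothetical sketch: forEachRemaining pushes all elements, flag or no flag.
class ForEachRemainingDemo {

    // Returns how many elements reach the consumer.
    static int countVisited(String[] source, String stopAt) {
        boolean[] cancelled = {false};
        int[] visited = {0};
        Spliterator<String> spliterator = Arrays.spliterator(source);
        spliterator.forEachRemaining(ele -> {
            visited[0]++;
            if (ele.equalsIgnoreCase(stopAt)) {
                cancelled[0] = true; // set, but forEachRemaining never looks at it
            }
        });
        return visited[0];
    }

    public static void main(String[] args) {
        String[] source = {"Sample3", "Sample4", "Sample5"};
        System.out.println(countVisited(source, "Sample4")); // 3
    }
}
```

All three elements are visited, including "Sample5", which arrives after cancellation was requested.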

Solution 2 - Java

This is a bug no matter how I look at it - and thank you Holger for your comments. I did not want to put this answer in here (seriously!), but none of the answers clearly states that this is a bug.

People are saying that this has to do with ordered/unordered streams, and this is not true, as this will report true three times:

Stream<String[]> s1 = Arrays.stream(strArray);
System.out.println(s1.spliterator().hasCharacteristics(Spliterator.ORDERED));

Stream<String> s2 = Arrays.stream(strArray)
            .flatMap(indStream -> Arrays.stream(indStream));
System.out.println(s2.spliterator().hasCharacteristics(Spliterator.ORDERED));

Stream<String> s3 = Arrays.stream(strArray)
            .flatMap(indStream -> Arrays.stream(indStream))
            .takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"));
System.out.println(s3.spliterator().hasCharacteristics(Spliterator.ORDERED));

It's very interesting also that if you change it to:

String[][] strArray = { 
         { "Sample1", "Sample2" }, 
         { "Sample3", "Sample5", "Sample4" }, // Sample4 is the last one here
         { "Sample7", "Sample8" } 
};

then Sample7 and Sample8 will not be part of the output; otherwise they will. It seems that flatMap ignores the cancel flag that takeWhile introduces.
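One way to reason about this observation is a simplified model of the JDK 9 behavior. It assumes that the cancel flag is consulted only between sub-arrays, while each sub-array is pushed through the buggy accept without any check. CancelBetweenArraysDemo and simulate are hypothetical names; this is a simulation, not the actual pipeline.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical model of the JDK 9 behavior; not the real stream pipeline.
class CancelBetweenArraysDemo {

    static List<String> simulate(String[][] arrays, Predicate<String> predicate) {
        List<String> out = new ArrayList<>();
        boolean take = true;
        for (String[] inner : arrays) {
            if (!take) {
                break; // assumption: cancellation is checked only between sub-arrays
            }
            for (String t : inner) { // the sub-array is pushed without any check
                if (take = predicate.test(t)) { // buggy accept overwrites the flag
                    out.add(t);
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        Predicate<String> notSample4 = ele -> !ele.equalsIgnoreCase("Sample4");
        String[][] sample4Last = {
                {"Sample1", "Sample2"}, {"Sample3", "Sample5", "Sample4"}, {"Sample7", "Sample8"}};
        String[][] sample4Middle = {
                {"Sample1", "Sample2"}, {"Sample3", "Sample4", "Sample5"}, {"Sample7", "Sample8"}};
        // [Sample1, Sample2, Sample3, Sample5]
        System.out.println(simulate(sample4Last, notSample4));
        // [Sample1, Sample2, Sample3, Sample5, Sample7, Sample8]
        System.out.println(simulate(sample4Middle, notSample4));
    }
}
```

In this model, when "Sample4" is the last element of its sub-array, the flag is still false at the next check and the loop stops. When "Sample5" follows it, the flag is reset to true and the remaining sub-arrays are processed.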

Solution 3 - Java

If you look at the documentation for takeWhile:

> if this stream is ordered, [returns] a stream consisting of the longest prefix of elements taken from this stream that match the given predicate.
>
> if this stream is unordered, [returns] a stream consisting of a subset of elements taken from this stream that match the given predicate.

Your stream is coincidentally ordered, but takeWhile doesn't know that it is. As such, it falls under the second case and returns a subset. Your takeWhile is just acting like a filter.

If you add a call to sorted before takeWhile, you'll see the result you expect:

Arrays.stream(strArray)
      .flatMap(indStream -> Arrays.stream(indStream))
      .sorted()
      .takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"))
      .forEach(ele -> System.out.println(ele));

Solution 4 - Java

The reason is that flatMap is itself an intermediate operation, and here it is combined with takeWhile, which is one of the stateful short-circuiting intermediate operations.

The behavior of flatMap, as pointed out by Holger in this answer, is certainly a reference one shouldn't miss when trying to understand the unexpected output of such short-circuiting operations.

Your expected result can be achieved by separating these two intermediate operations with a terminal operation, so that takeWhile then runs deterministically on an ordered stream. For example:

List<String> sampleList = Arrays.stream(strArray).flatMap(Arrays::stream).collect(Collectors.toList());
sampleList.stream().takeWhile(ele -> !ele.equalsIgnoreCase("Sample4"))
            .forEach(System.out::println);

Also, there seems to be a related bug, JDK-8075939, already registered to track this behavior.

Edit: This can be tracked further at JDK-8193856 accepted as a bug.

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | Jeevan Varughese | View Question on Stackoverflow
Solution 1 - Java | Nicolai Parlog | View Answer on Stackoverflow
Solution 2 - Java | Eugene | View Answer on Stackoverflow
Solution 3 - Java | Michael | View Answer on Stackoverflow
Solution 4 - Java | Naman | View Answer on Stackoverflow