Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 3 additions & 30 deletions src/libraries/System.Linq/src/System/Linq/Skip.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,36 +121,9 @@ public static IEnumerable<TSource> SkipLast<TSource>(this IEnumerable<TSource> s

return count <= 0 ?
source.Skip(0) :
SkipLastIterator(source, count);
}

private static IEnumerable<TSource> SkipLastIterator<TSource>(IEnumerable<TSource> source, int count)
{
Debug.Assert(source != null);
Debug.Assert(count > 0);

var queue = new Queue<TSource>();

using (IEnumerator<TSource> e = source.GetEnumerator())
{
while (e.MoveNext())
{
if (queue.Count == count)
{
do
{
yield return queue.Dequeue();
queue.Enqueue(e.Current);
}
while (e.MoveNext());
break;
}
else
{
queue.Enqueue(e.Current);
}
}
}
TakeRangeFromEndIterator(source,
isStartIndexFromEnd: false, startIndex: 0,
isEndIndexFromEnd: true, endIndex: count);
}
}
}
25 changes: 25 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Take.SizeOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,30 @@ private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> s
if (--count == 0) break;
}
}

private static IEnumerable<TSource> TakeRangeIterator<TSource>(IEnumerable<TSource> source, int startIndex, int endIndex)
{
Debug.Assert(source != null);
Debug.Assert(startIndex >= 0 && startIndex < endIndex);

using IEnumerator<TSource> e = source.GetEnumerator();

int index = 0;
while (index < startIndex && e.MoveNext())
{
++index;
}

if (index < startIndex)
{
yield break;
}

while (index < endIndex && e.MoveNext())
{
yield return e.Current;
++index;
}
}
}
}
32 changes: 28 additions & 4 deletions src/libraries/System.Linq/src/System/Linq/Take.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,38 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;

namespace System.Linq
{
public static partial class Enumerable
{
private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count) =>
source is IPartition<TSource> partition ? partition.Take(count) :
source is IList<TSource> sourceList ? (IEnumerable<TSource>)new ListPartition<TSource>(sourceList, 0, count - 1) :
new EnumerablePartition<TSource>(source, 0, count - 1);
private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
{
Debug.Assert(source != null);
Debug.Assert(count > 0);

return
source is IPartition<TSource> partition ? partition.Take(count) :
source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, 0, count - 1) :
new EnumerablePartition<TSource>(source, 0, count - 1);
}

private static IEnumerable<TSource> TakeRangeIterator<TSource>(IEnumerable<TSource> source, int startIndex, int endIndex)
{
Debug.Assert(source != null);
Debug.Assert(startIndex >= 0 && startIndex < endIndex);

return
source is IPartition<TSource> partition ? TakePartitionRange(partition, startIndex, endIndex) :
source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, startIndex, endIndex - 1) :
new EnumerablePartition<TSource>(source, startIndex, endIndex - 1);

static IPartition<TSource> TakePartitionRange(IPartition<TSource> partition, int startIndex, int endIndex)
{
partition = endIndex == 0 ? EmptyPartition<TSource>.Instance : partition.Take(endIndex);
return startIndex == 0 ? partition : partition.Skip(startIndex);
}
Comment thread
eiriktsarpalis marked this conversation as resolved.
}
}
}
195 changes: 81 additions & 114 deletions src/libraries/System.Linq/src/System/Linq/Take.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,121 +55,129 @@ public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> sourc
{
return startIndex >= endIndex
? Empty<TSource>()
: source.Skip(startIndex).Take(endIndex - startIndex);
: TakeRangeIterator(source, startIndex, endIndex);
}

return TakeIterator(source, isStartIndexFromEnd, startIndex, isEndIndexFromEnd, endIndex);
return TakeRangeFromEndIterator(source, isStartIndexFromEnd, startIndex, isEndIndexFromEnd, endIndex);
}

private static IEnumerable<TSource> TakeIterator<TSource>(
IEnumerable<TSource> source, bool isStartIndexFromEnd, int startIndex, bool isEndIndexFromEnd, int endIndex)
private static IEnumerable<TSource> TakeRangeFromEndIterator<TSource>(IEnumerable<TSource> source, bool isStartIndexFromEnd, int startIndex, bool isEndIndexFromEnd, int endIndex)
{
Debug.Assert(source != null);
Debug.Assert(isStartIndexFromEnd || isEndIndexFromEnd);
Debug.Assert(isStartIndexFromEnd
? startIndex > 0 && (!isEndIndexFromEnd || startIndex > endIndex)
: startIndex >= 0 && (isEndIndexFromEnd || startIndex < endIndex));
Debug.Assert(endIndex >= 0);

using IEnumerator<TSource> e = source.GetEnumerator();
if (isStartIndexFromEnd)
// Attempt to extract the count of the source enumerator,
// in order to convert fromEnd indices to regular indices.
// Enumerable counts can change over time, so it is very
// important that this check happens at enumeration time;
// do not move it outside of the iterator method.
if (source.TryGetNonEnumeratedCount(out int count))
{
if (!e.MoveNext())
{
yield break;
}

int index = 0;
Queue<TSource> queue = new();
queue.Enqueue(e.Current);
startIndex = CalculateStartIndex(isStartIndexFromEnd, startIndex, count);
endIndex = CalculateEndIndex(isEndIndexFromEnd, endIndex, count);

while (e.MoveNext())
if (startIndex < endIndex)
{
checked
foreach (TSource element in TakeRangeIterator(source, startIndex, endIndex))
{
index++;
yield return element;
}
}

yield break;
}

Queue<TSource> queue;

if (queue.Count == startIndex)
if (isStartIndexFromEnd)
{
// TakeLast compat: enumerator should be disposed before yielding the first element.
using (IEnumerator<TSource> e = source.GetEnumerator())
{
if (!e.MoveNext())
{
queue.Dequeue();
yield break;
}

queue = new Queue<TSource>();
queue.Enqueue(e.Current);
}

int count = checked(index + 1);
Debug.Assert(queue.Count == Math.Min(count, startIndex));
count = 1;

startIndex = count - startIndex;
if (startIndex < 0)
{
startIndex = 0;
}
while (e.MoveNext())
{
if (count < startIndex)
{
queue.Enqueue(e.Current);
++count;
}
else
{
do
{
queue.Dequeue();
queue.Enqueue(e.Current);
checked { ++count; }
} while (e.MoveNext());
break;
}
}

if (isEndIndexFromEnd)
{
endIndex = count - endIndex;
}
else if (endIndex > count)
{
endIndex = count;
Debug.Assert(queue.Count == Math.Min(count, startIndex));
}

startIndex = CalculateStartIndex(isStartIndexFromEnd: true, startIndex, count);
endIndex = CalculateEndIndex(isEndIndexFromEnd, endIndex, count);
Debug.Assert(endIndex - startIndex <= queue.Count);

for (int rangeIndex = startIndex; rangeIndex < endIndex; rangeIndex++)
{
yield return queue.Dequeue();
}
}
else
{
int index = 0;
while (index <= startIndex)
{
if (!e.MoveNext())
{
yield break;
}
Debug.Assert(!isStartIndexFromEnd && isEndIndexFromEnd);

checked
{
index++;
}
// SkipLast compat: the enumerator should be disposed at the end of the enumeration.
using IEnumerator<TSource> e = source.GetEnumerator();

count = 0;
while (count < startIndex && e.MoveNext())
{
++count;
}

if (isEndIndexFromEnd)
if (count == startIndex)
{
if (endIndex > 0)
queue = new Queue<TSource>();
while (e.MoveNext())
{
Queue<TSource> queue = new();
do
if (queue.Count == endIndex)
{
if (queue.Count == endIndex)
do
{
queue.Enqueue(e.Current);
yield return queue.Dequeue();
}
} while (e.MoveNext());

queue.Enqueue(e.Current);
} while (e.MoveNext());
}
else
{
do
break;
}
else
{
yield return e.Current;
} while (e.MoveNext());
}
}
else
{
Debug.Assert(index < endIndex);
yield return e.Current;
while (checked(++index) < endIndex && e.MoveNext())
{
yield return e.Current;
queue.Enqueue(e.Current);
}
}
}
}

static int CalculateStartIndex(bool isStartIndexFromEnd, int startIndex, int count) =>
Math.Max(0, isStartIndexFromEnd ? count - startIndex : startIndex);

static int CalculateEndIndex(bool isEndIndexFromEnd, int endIndex, int count) =>
Math.Min(count, isEndIndexFromEnd ? count - endIndex : endIndex);
}

public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
Expand Down Expand Up @@ -243,50 +251,9 @@ public static IEnumerable<TSource> TakeLast<TSource>(this IEnumerable<TSource> s

return count <= 0 ?
Empty<TSource>() :
TakeLastIterator(source, count);
}

private static IEnumerable<TSource> TakeLastIterator<TSource>(IEnumerable<TSource> source, int count)
{
Debug.Assert(source != null);
Debug.Assert(count > 0);

Queue<TSource> queue;
using (IEnumerator<TSource> e = source.GetEnumerator())
{
if (!e.MoveNext())
{
yield break;
}

queue = new Queue<TSource>();
queue.Enqueue(e.Current);

while (e.MoveNext())
{
if (queue.Count < count)
{
queue.Enqueue(e.Current);
}
else
{
do
{
queue.Dequeue();
queue.Enqueue(e.Current);
}
while (e.MoveNext());
break;
}
}
}

Debug.Assert(queue.Count <= count);
do
{
yield return queue.Dequeue();
}
while (queue.Count > 0);
TakeRangeFromEndIterator(source,
isStartIndexFromEnd: true, startIndex: count,
isEndIndexFromEnd: true, endIndex: 0);
}
}
}
Loading