Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,60 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
}
Span<T> dstSpan = MemoryMarshal.CreateSpan(ref destination._reference, (int)totalLength);

for (int i = 0; i < tensors.Length; i++)
if (dimension == 0 || dimension == -1)
{
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(tensors[i], dstSpan);
dstSpan = dstSpan.Slice((int)tensors[i].FlattenedLength);
for (int i = 0; i < tensors.Length; i++)
{
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(tensors[i], dstSpan);
dstSpan = dstSpan.Slice((int)tensors[i].FlattenedLength);
}
}
else
{
Span<NRange> ranges = TensorOperation.RentedBuffer.CreateUninitialized(destination.Rank, out TensorOperation.RentedBuffer<NRange> rentedBuffer);
for (int i = 0; i < dimension; i++)
{
ranges[i] = 0..1;
}
for (int i = dimension; i < destination.Rank; i++)
{
ranges[i] = ..;
}

bool hasMore = true;
while (hasMore)
{
for (int i = 0; i < tensors.Length; i++)
{
Tensor<T> slice = tensors[i].Slice(ranges);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(slice, dstSpan);
dstSpan = dstSpan.Slice((int)slice.FlattenedLength);
}
hasMore = IncrementIndexes(ranges, dimension, destination.Lengths);
}
rentedBuffer.Dispose();
}
return ref destination;
}

private static bool IncrementIndexes(Span<NRange> ranges, int dimension, ReadOnlySpan<nint> lengths)
{
NRange curRange = ranges[dimension - 1];
ranges[dimension - 1] = new NRange(curRange.Start.Value + 1, curRange.End.Value + 1);

for (int i = dimension - 1; i >= 0; i--)
{
if (ranges[i].Start.Value >= lengths[i])
{
ranges[i] = 0..1;
if (i == 0)
return false;
ranges[i - 1] = new NRange(ranges[i - 1].Start.Value + 1, ranges[i - 1].End.Value + 1);
}
}
return true;
}

private static nint CalculateCopyLength(ReadOnlySpan<nint> lengths, int startingAxis)
{
// When starting axis is -1 we want all the data at once same as if starting axis is 0
Expand Down Expand Up @@ -4797,9 +4842,8 @@ public static ref readonly TensorSpan<T> Negate<T>(scoped in ReadOnlyTensorSpan<
public static T Norm<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : IRootFunctions<T>
{
// TODO: TANNER ADVICE
T result = T.AdditiveIdentity;
TensorOperation.Invoke<TensorOperation.SumOfSquaredDifferences<T>, T, T>(x, T.AdditiveIdentity, ref result);
TensorOperation.Invoke<TensorOperation.SumOfSquares<T>, T, T>(x, ref result);
return T.Sqrt(result);
}
#endregion
Expand Down
Loading