
Commit 0032fa7

bdhirsh authored and facebook-github-bot committed
Add a Functionalization pass in core (#64432)
Summary:
Pull Request resolved: #64432

Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.

**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)

**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, all of which are needed by functorch, fused together into one big pass (see the sketch after this list):
  * (a) Alias removal: replace `{view}` calls with `{view}_copy` calls and manually track aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of its aliases. This is the bulk of the work - once it's done, the next two steps are trivial to implement.
  * (b) Mutation removal, which is easy to do once we know there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`.
  * (c) Reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization we can make specifically for functorch (and strided backends), which only care about mutation removal, not alias removal.
* XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
* There are currently no `{view}_copy` operators, because the pass's <replace views with copies> and <replace copies with views> steps have just been combined. Later, we'll want to actually implement `{view}_copy` variants of each view operator, probably with codegen.
* Documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub collapses by default because it's large)
* Documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
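To make (a)/(b)/(c) concrete, here is a minimal sketch (not code from this PR) of the rewrites applied to a toy program, written with plain ATen ops; the out-of-place reassignments of `base` stand in for the pass's internal `replace_`/update bookkeeping:

```cpp
#include <ATen/ATen.h>

// Toy program: v = base.view({16}); v.add_(1)  (the add_ also mutates base, since v aliases it).
void pass_steps_sketch() {
  at::Tensor base = at::ones({2, 8});

  // (a) alias removal: the view becomes a copy, and the mutation is replayed
  //     back onto base through the view's inverse (for `view`, that's just another view call).
  at::Tensor v_copy = base.view({16}).clone();   // stand-in for a {view}_copy op
  v_copy.add_(1);
  base = v_copy.view({2, 8});                    // replay the update onto base

  // (b) mutation removal: every `x.add_(y)` becomes `x = x.add(y)`.
  at::Tensor v2 = base.view({16}).clone();
  v2 = v2.add(1);
  base = v2.view({2, 8});

  // (c) reapply views (functorch/strided-only optimization): the copies turn back
  //     into real views, which is safe now that no mutations remain.
  at::Tensor v3 = base.view({16});
  at::Tensor v3_new = v3.add(1);
  base = v3_new.view({2, 8});
}
```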
**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views. This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`. This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a more complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are going to require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information. It's a little annoying to do this, though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But some `at::native::{view}` kernels call other tensor methods, which re-invoke the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping. To handle this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor method called inside of the `at::native::{view}` op always redispatches straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.

(4) `FunctionalTensorWrapper` objects accurately report aliasing information. There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage: if two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage().is_alias_of(b.storage())` should return true (see the sketch below). I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). One thing I'm not sure about: should `FunctionalTensorWrapper` set `storage_access_should_throw_` (a) always, (b) never, or (c) only if its wrapped tensor has it set? Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from Python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to Python in the future (even though it contains garbage)?
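As a concrete illustration of the invariant in (4), here is a minimal sketch (not code from this PR) of what `Storage::is_alias_of` reports for ordinary strided views - the same answer a `FunctionalTensorWrapper` is now expected to give:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// The aliasing invariant that FunctionalStorageImpl preserves, demonstrated
// here with ordinary strided tensors.
void storage_aliasing_sketch() {
  at::Tensor a = at::ones({4, 4});
  at::Tensor b = a.view({16});            // b is a view of a
  at::Tensor c = at::ones({16});          // c is unrelated

  TORCH_CHECK(a.storage().is_alias_of(b.storage()));   // views share storage
  TORCH_CHECK(!a.storage().is_alias_of(c.storage()));  // fresh tensors don't
}
```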
(5) Better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free. I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible; they're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or CPU implementation. There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later; for now I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢). From some light testing, it looks like they work correctly, with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`. These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op. The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()). I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses `view().copy_()` calls that XLA won't be able to handle. I'm wondering if anyone has any objections; otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing XLA view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos, similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences, though:
* The autograd implementations of those backward functions (like `permute_backwards()` in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators instead.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions (see the sketch after this list). It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes

A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else) and automatically generate the right thing. I left that out of scope for this PR, though.
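A minimal sketch (not code from this PR) of the "zeroes vs. base" difference above, using `select` as the example view op and assuming `at::select_backward` has its usual `(grad, input_sizes, dim, index)` signature:

```cpp
#include <ATen/ATen.h>

// Autograd fills the non-selected positions with zeroes; the functionalization
// inverse fills them with the corresponding values of `base`.
void select_inverse_vs_backward_sketch() {
  at::Tensor base = at::ones({3, 4});
  at::Tensor mutated_view = base.select(/*dim=*/0, /*index=*/1) + 5;

  // Autograd: embed the row into a tensor of ZEROS shaped like base.
  at::Tensor from_autograd =
      at::select_backward(mutated_view, base.sizes(), /*dim=*/0, /*index=*/1);

  // Functionalization: embed the row into a copy of BASE itself,
  // which is exactly what select_scatter does.
  at::Tensor from_functionalization =
      base.select_scatter(mutated_view, /*dim=*/0, /*index=*/1);
}
```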
**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys.
* Once we have more dispatch keys, split this pass up into 3 pieces - it's currently fused, and doesn't do the right thing for Vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that Vulkan/XLA know how to implement (see the hypothetical sketch after this list), and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes use user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the Vulkan/XLA use cases. There are several areas to improve perf, with varying levels of effort required. The simplest one, which I'll probably do regardless, is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for XLA will also be useful to test the view operator coverage.
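For reference, a hypothetical sketch of what a `{view}_copy` variant could reduce to semantically; `transpose_copy_sketch` is an illustrative name, not an operator added by this PR:

```cpp
#include <ATen/ATen.h>

// Semantically, a {view}_copy operator produces a freshly-allocated tensor with
// the viewed contents, rather than an alias into self's storage. That's what
// makes it implementable by backends like Vulkan/XLA that can't represent aliasing.
at::Tensor transpose_copy_sketch(const at::Tensor& self, int64_t dim0, int64_t dim1) {
  // clone() guarantees the result owns its own storage.
  return self.transpose(dim0, dim1).clone();
}
```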
**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  ::std::vector<at::Tensor> out;
  {
    at::AutoDispatchBelowFunctionalize guard;
    auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
    out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
    // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
    // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
  }
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
      return base.split(split_size, dim)[mutated_view_idx];
    },
    [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
    }
  );
  at::functionalization::impl::set_view_meta(out, self, view_meta);
  at::AutoDispatchDirectlyToNative native_guard;
  ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
  at::functionalization::impl::set_strides(out, reference_tensor_output);
  return out;
}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  at::functionalization::impl::sync(self);
  at::functionalization::impl::sync(other);
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
  at::Tensor tmp_output;
  {
    at::AutoDispatchBelowFunctionalize guard;
    // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
    tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self_, other_, alpha);
  }
  self.replace_(tmp_output);
  at::functionalization::impl::maybe_add_update(self);
  return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
      return base.transpose(dim0, dim1);
    },
    [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
    }
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  // See Note [Propagating strides in the functionalization pass]
  // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
  // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
  // Its only job is to directly compute the output size/stride/storage_offset metadata.
  at::AutoDispatchDirectlyToNative native_guard;
  at::native::transpose_(self, dim0, dim1);
  return self;
}
```

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942093

Pulled By: bdhirsh

fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
1 parent b0a8ca2 commit 0032fa7

28 files changed: +1809 −33 lines

BUILD.bazel

Lines changed: 6 additions & 0 deletions
@@ -127,6 +127,11 @@ genrule(
     "aten/src/ATen/Declarations.yaml",
     "aten/src/ATen/RegisterBackendSelect.cpp",
     "aten/src/ATen/RegisterCPU.cpp",
+    "aten/src/ATen/RegisterFunctionalization_0.cpp",
+    "aten/src/ATen/RegisterFunctionalization_1.cpp",
+    "aten/src/ATen/RegisterFunctionalization_2.cpp",
+    "aten/src/ATen/RegisterFunctionalization_3.cpp",
+    # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
     "aten/src/ATen/RegisterMkldnnCPU.cpp",
     "aten/src/ATen/RegisterQuantizedCPU.cpp",
     "aten/src/ATen/RegisterSparseCPU.cpp",
@@ -143,6 +148,7 @@ genrule(
     "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
+    "aten/src/ATen/FunctionalInverses.h",
     "aten/src/ATen/Functions.h",
     "aten/src/ATen/Functions.cpp",
     "aten/src/ATen/RedispatchFunctions.h",
FunctionalInverses.cpp

Lines changed: 234 additions & 0 deletions

#include <ATen/FunctionalInverses.h>

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>

namespace at {
namespace functionalization {

// This logic is similar to autograd code for view backwards calls.
// We can't easily share it though, because (eventually) these functions
// will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`.

Tensor permute_inverse(const Tensor& self, IntArrayRef dims) {
  // invert the permutation
  auto ndims = dims.size();
  std::vector<int64_t> dims_(ndims);
  for(const auto i : c10::irange(ndims)) {
    dims_[at::maybe_wrap_dim(dims[i], ndims)] = i;
  }
  return self.permute(dims_);
}

Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
  auto result = self;

  int64_t nDims = sizes.size();
  for(const auto dim : c10::irange(nDims)) {
    if (sizes[dim] == 1) {
      result = result.unsqueeze(dim);
    }
  }
  return result;
}

Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
  // unsqueezing in the backward.
  if (sizes.size() > 0 && sizes[dim] == 1) {
    return self.unsqueeze(dim);
  }
  return self;
}

// Note [Functionalization Pass: View Inverses].
// This file contains the implementation of each "view inverse".
// These aren't really true inverses in the mathematical sense: each view inverse describes how to undo
// the original view (although it takes in different arguments).
//
// E.g. Below is an example of a program that has alias operations removed, and the role that view inverses play:
//
// normal program with views and mutations:
// view1 = input1.view_op(args...)
// view1.add_(1) (perform a mutation on the view, which should also modify input)
//
// version of the program with no aliasing, that instead uses view_inverse functions:
// view_copy1 = input1.view_copy_op(args...)
// view_copy1.add_(1) (perform a mutation on view_copy1. At this point, input1 is NOT modified)
// x = view_op_inverse(input1, view_copy1, args...)
//
// at this point, input1 and x should be equal
//
// Note that input1 is also passed as an argument to view_op_inverse in the above example.
// This isn't actually required for most view operators: it's only required for view ops
// where you can't figure out what the size of the base tensor is given just the view tensor and arguments.
// Examples are slice/select/scatter/squeeze/as_strided.
// We happen to be passing in the base tensor in all cases, mostly to make the codegen simpler.
// But you'll see below that the "base" argument is ignored by most view_inverse implementations.

// ----------------------------------------------------------
// Implementations of each view_inverse() function are below.
// One of these needs to be implemented for every existing non-composite view operator.
// The codegen automatically generates the corresponding function declaration.
// ----------------------------------------------------------

Tensor FunctionalInverses::_fw_primal_inverse(const at::Tensor& base, const at::Tensor& mutated_view, int64_t level) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported.");
  return Tensor();
}

Tensor FunctionalInverses::view_as_real_inverse(const Tensor& base, const Tensor& mutated_view) {
  return at::view_as_complex(mutated_view);
}

Tensor FunctionalInverses::view_as_complex_inverse(const Tensor& base, const Tensor& mutated_view) {
  return at::view_as_real(mutated_view.resolve_conj());
}

Tensor FunctionalInverses::_conj_inverse(const Tensor& base, const Tensor& mutated_view) {
  return mutated_view.conj();
}

Tensor FunctionalInverses::_neg_view_inverse(const Tensor& base, const Tensor& mutated_view) {
  return mutated_view.neg();
}

Tensor FunctionalInverses::as_strided_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "as_strided has not been implemented in the functionalization pass yet");
  return Tensor();
}

Tensor FunctionalInverses::diagonal_inverse(const Tensor& base, const Tensor& mutated_view, int64_t offset, int64_t dim1, int64_t dim2) {
  return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
}

Tensor FunctionalInverses::expand_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size, bool implicit) {
  return at::sum_to(mutated_view, base.sizes());
}

Tensor FunctionalInverses::permute_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef dims) {
  return at::functionalization::permute_inverse(mutated_view, dims);
}

Tensor FunctionalInverses::_reshape_alias_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size, at::IntArrayRef stride) {
  // Note that I'm directly calling reshape(), and ignoring the strides.
  // _reshape_alias() isn't available from user code, and is an implementation detail of reshape().
  // Specifically, passing in the strides directly can get us into trouble in cases like:
  // b = a[0]; c = b.reshape(...); c.add_(1); print(a)
  // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides,
  // the call would fail because `mutated_view` doesn't have enough bytes of storage.
  return mutated_view.reshape(base.sizes());
}

Tensor FunctionalInverses::select_int_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dim, int64_t index) {
  return base.select_scatter(mutated_view, dim, index);
}

Tensor FunctionalInverses::detach_inverse(const Tensor& base, const Tensor& mutated_view) {
  // the functionalization pass doesn't care about autograd metadata - as a view, I think detach() is just an identity function
  return mutated_view;
}

Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
  return base.slice_scatter(mutated_view, dim, start, end, step);
}

Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, int64_t mutated_view_idx, int64_t split_size, int64_t dim) {
  // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
  // For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
  // on top of the base tensor.
  // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
  dim = at::maybe_wrap_dim(dim, base.sizes().size());
  auto dim_size = base.size(dim);
  auto start = mutated_view_idx * split_size;
  auto end = start + split_size;
  if (end > dim_size) end = dim_size;
  return base.slice_scatter(mutated_view, dim, start, end, 1);
}

Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, int64_t mutated_view_idx, at::IntArrayRef split_sizes, int64_t dim) {
  dim = at::maybe_wrap_dim(dim, base.sizes().size());
  auto dim_size = base.size(dim);
  int64_t start = 0;
  for (auto i = 0; i < mutated_view_idx; ++i) {
    start += split_sizes[i];
  }
  auto end = start + split_sizes[mutated_view_idx];
  if (end > dim_size) end = dim_size;
  return base.slice_scatter(mutated_view, dim, start, end, 1);
}

Tensor FunctionalInverses::squeeze_inverse(const Tensor& base, const Tensor& mutated_view) {
  return unsqueeze_to(mutated_view, base.sizes());
}

Tensor FunctionalInverses::squeeze_dim_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dim) {
  return unsqueeze_to(mutated_view, dim, base.sizes());
}

Tensor FunctionalInverses::t_inverse(const Tensor& base, const Tensor& mutated_view) {
  return mutated_view.t();
}

Tensor FunctionalInverses::transpose_int_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dim0, int64_t dim1) {
  return mutated_view.transpose(dim0, dim1);
}

Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dim) {
  return mutated_view.squeeze(dim);
}

Tensor FunctionalInverses::_indices_inverse(const Tensor& base, const Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::_values_inverse(const Tensor& base, const Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::indices_inverse(const Tensor& base, const Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::col_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::unbind_int_inverse(const Tensor& base, const Tensor& mutated_view, int64_t mutated_view_idx, int64_t dim) {
  dim = at::maybe_wrap_dim(dim, base.sizes().size());
  return base.select_scatter(mutated_view, dim, mutated_view_idx);
}

Tensor FunctionalInverses::view_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size) {
  return mutated_view.view(base.sizes());
}

Tensor FunctionalInverses::view_dtype_inverse(const Tensor& base, const Tensor& mutated_view, at::ScalarType dtype) {
  return mutated_view.view(base.scalar_type());
}

Tensor FunctionalInverses::unfold_inverse(const Tensor& base, const Tensor& mutated_view, int64_t dimension, int64_t size, int64_t step) {
  // I think autograd and the functionalization pass want the exact same thing here, but need to test to confirm.
  return unfold_backward(mutated_view, base.sizes(), dimension, size, step);
}

Tensor FunctionalInverses::alias_inverse(const Tensor& base, const Tensor& mutated_view) {
  return mutated_view;
}

} // functionalization
} // at
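A minimal usage sketch (not part of the diff) of what `slice_Tensor_inverse` above is expected to compute, written with plain ATen ops:

```cpp
#include <ATen/ATen.h>

// slice_scatter embeds the mutated slice back into base at the positions the slice came from.
void slice_inverse_sketch() {
  at::Tensor base = at::zeros({4, 4});
  at::Tensor view = base.slice(/*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1);
  at::Tensor mutated_view = view + 1;  // out-of-place stand-in for view.add_(1)

  at::Tensor mutated_base =
      base.slice_scatter(mutated_view, /*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1);
  // mutated_base matches what `base` would look like if view.add_(1) had aliased it.
}
```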
FunctionalStorageImpl.cpp

Lines changed: 117 additions & 0 deletions

#include <ATen/FunctionalStorageImpl.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at {
namespace functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, out_idx);
}

Alias::Alias(const at::Tensor& base) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base));
  base_ = base;
}

const at::Tensor& Alias::base() const {
  return base_;
}

void Alias::add_update(const at::Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  updates_.push_back({updated_val, metas});
  generation_++;
}

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the Alias object.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the alias, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is output index, these are all single output views so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated alias.
// alias.base_ = t
const Tensor apply_update(const Alias::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  std::vector<at::Tensor> tmp_values({base});
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we need to, we could probably apply this optimization and only bother computing tmp_values
    // for those necessary view ops.
    tmp_values.push_back(std::move(next_view));
  }
  for(int i = update.view_metas.size()-1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in FunctionalInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}

void Alias::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  for (auto& update_data: updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
}

FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& value)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      value.numel() * value.dtype().itemsize(),
      DataPtr{nullptr, value.device()},
      // Using a null allocator, since FunctionalTensorImpls aren't resizeable.
      nullptr,
      /*resizeable=*/false
    ),
    alias_(Alias(value))
  {}

void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& view_metas) {
  alias_.add_update(updated_val, view_metas);
}

void FunctionalStorageImpl::apply_updates() {
  alias_.apply_updates();
}

const Tensor& FunctionalStorageImpl::base() {
  return alias_.base();
}

size_t FunctionalStorageImpl::generation() const {
  return alias_.generation();
}

} // namespace functionalization
} // namespace at
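A minimal sketch (not part of the diff) of the replay that `apply_update` performs, unrolled by hand for a two-view chain using plain ATen ops:

```cpp
#include <ATen/ATen.h>

// Hand-unrolled version of apply_update() for the chain:
//   a = base.transpose(0, 1); c = a.select(0, 2); c.add_(3)
void apply_update_sketch() {
  at::Tensor base = at::zeros({3, 4});
  at::Tensor a = base.transpose(0, 1);          // forward_fn of view1
  at::Tensor c_new = a.select(0, 2) + 3;        // update.new_val (mutated leaf view)

  // tmp_values = [base, a]; replay the inverses bottom-up, like the reverse_fn loop:
  at::Tensor t = a.select_scatter(c_new, 0, 2); // select_int_inverse(a, c_new, ...)
  at::Tensor mutated_base = t.transpose(0, 1);  // transpose_int_inverse(base, t, 0, 1)
  // mutated_base is what base would contain if c.add_(3) had gone through the aliases.
}
```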
