
Commit 9e5f141

Update XLATensor namespace calls to tensor_methods

1 parent a81964f

File tree

3 files changed (+48, -35 lines)

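Context for the change: the commit moves call sites off static XLATensor class methods and onto free functions in the tensor_methods namespace. A minimal sketch of the before/after pattern, using placeholder types (this XLATensor and clone are illustrative stand-ins, not the real torch_xla declarations):

#include <memory>

// Placeholder stand-ins for the real torch_xla types; illustrative only.
struct XLATensor {};
using XLATensorPtr = std::shared_ptr<XLATensor>;

// Before this commit, call sites read XLATensor::clone(t), i.e. the ops were
// static members of the tensor class. After it, the same ops live in a
// tensor_methods namespace and call sites read tensor_methods::clone(t).
namespace tensor_methods {
XLATensorPtr clone(const XLATensorPtr& input) {
  return std::make_shared<XLATensor>(*input);  // deep-copy placeholder
}
}  // namespace tensor_methods

int main() {
  auto t = std::make_shared<XLATensor>();
  auto c = tensor_methods::clone(t);  // updated call style
  (void)c;
}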

torch_xla/csrc/aten_cpu_fallback.cpp

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@ static std::unordered_map<std::string, ::xla::metrics::Counter*>
     _cpu_fallback_counters;
 
 void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  std::cout << "WONJOO: at aten_cpu_fallback.cpp, xla_cpu_fallback1" << std::endl;
+  std::cout << "WONJOO: at aten_cpu_fallback.cpp, xla_cpu_fallback1"
+            << std::endl;
   XLA_FN_TRACK(3);
   const auto name = c10::toString(op.operator_name());

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 15 additions & 7 deletions
@@ -105,7 +105,9 @@ std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
 
 torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
     const at::Tensor& tensor, const torch::lazy::BackendDevice& device) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber" << std::endl;
+  std::cout
+      << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber"
+      << std::endl;
   if (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
       (tensor.dim() == 0 && tensor.numel() == 1)) {
     return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
@@ -116,7 +118,8 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
 
 XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
                                   const torch::lazy::BackendDevice& device) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor"
+            << std::endl;
   if (!tensor.defined()) {
     return XLATensorPtr();
   }
@@ -130,7 +133,8 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
 
 XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
                                   const torch::lazy::BackendDevice& device) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor"
+            << std::endl;
   if (!IsDefined(tensor)) {
     return XLATensorPtr();
   }
@@ -142,7 +146,8 @@ XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
 std::vector<XLATensorPtr> GetOrCreateXlaTensors(
     absl::Span<const at::Tensor> tensors,
     const torch::lazy::BackendDevice& device) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors"
+            << std::endl;
   std::vector<XLATensorPtr> xla_tensors;
   for (const at::Tensor& tensor : tensors) {
     xla_tensors.push_back(bridge::GetOrCreateXlaTensor(tensor, device));
@@ -151,7 +156,8 @@ std::vector<XLATensorPtr> GetOrCreateXlaTensors(
 }
 
 std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList"
+            << std::endl;
   std::vector<at::Tensor> aten_xla_tensors(tensors.size());
   std::vector<XLATensorPtr> xla_tensors;
   // We need to separate out the defined tensors first, GetXlaTensor() doesn't
@@ -191,7 +197,8 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
 
 std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList(
     const std::vector<c10::optional<at::Tensor>>& tensors) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList"
+            << std::endl;
   std::vector<c10::optional<at::Tensor>> opt_aten_xla_tensors(tensors.size());
   std::vector<at::Tensor> materialized_tensors;
   std::vector<bool> to_translate(tensors.size());
@@ -380,7 +387,8 @@ at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
 
 std::vector<at::Tensor> AtenFromXlaTensors(
     absl::Span<const XLATensorPtr> xla_tensors) {
-  std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors" << std::endl;
+  std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors"
+            << std::endl;
   std::vector<at::Tensor> tensors;
   tensors.reserve(xla_tensors.size());
   for (auto& tensor : xla_tensors) {

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 31 additions & 27 deletions
@@ -684,10 +684,10 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
         storage_offset);
   }
   auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
-  auto base_clone = XLATensor::clone(base_);
-  auto base_clone_slice = XLATensor::as_strided(
+  auto base_clone = tensor_methods::clone(base_);
+  auto base_clone_slice = tensor_methods::as_strided(
       base_clone, xsize, xstride, XlaHelpers::I64Optional(storage_offset));
-  XLATensor::copy_(base_clone_slice, mutated_view_);
+  tensor_methods::copy_(base_clone_slice, mutated_view_);
   return bridge::AtenFromXlaTensor(base_clone);
 }
 
@@ -1041,9 +1041,10 @@ at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base,
                                                 int64_t dim2) {
   auto base_ = bridge::GetXlaTensor(base);
   auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
-  auto base_clone = XLATensor::clone(base_);
-  auto base_clone_slice = XLATensor::diagonal(base_clone, offset, dim1, dim2);
-  XLATensor::copy_(base_clone_slice, mutated_view_);
+  auto base_clone = tensor_methods::clone(base_);
+  auto base_clone_slice =
+      tensor_methods::diagonal(base_clone, offset, dim1, dim2);
+  tensor_methods::copy_(base_clone_slice, mutated_view_);
   return bridge::AtenFromXlaTensor(base_clone);
 }
 
@@ -2167,7 +2168,7 @@ at::Tensor& XLANativeFunctions::normal_(
 }
 
 at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self,
-                                       at::IntArrayRef dims) {
+                                            at::IntArrayRef dims) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(tensor_methods::permute(
       bridge::GetXlaTensor(self), XlaHelpers::I64List(dims)));
@@ -2544,10 +2545,10 @@ at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim,
 }
 
 at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim,
-                                      int64_t index) {
+                                           int64_t index) {
   TORCH_LAZY_FN_COUNTER("xla::");
   std::cout << "WONJOO: at XLANativeFunctions::select_copy1" << std::endl;
-  XLA_FN_COUNTER("xla::");
+  TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::select(bridge::GetXlaTensor(self), dim, index));
 }
@@ -2564,9 +2565,9 @@ at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base,
             << std::endl;
   auto base_ = bridge::GetXlaTensor(base);
   auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
-  auto base_clone = XLATensor::clone(base_);
-  auto base_clone_slice = XLATensor::select(base_clone, dim, index);
-  XLATensor::copy_(base_clone_slice, mutated_view_);
+  auto base_clone = tensor_methods::clone(base_);
+  auto base_clone_slice = tensor_methods::select(base_clone, dim, index);
+  tensor_methods::copy_(base_clone_slice, mutated_view_);
   return bridge::AtenFromXlaTensor(base_clone);
 }
 
@@ -2592,8 +2593,9 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
 }
 
 at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
-                                     c10::optional<int64_t> start,
-                                     c10::optional<int64_t> end, int64_t step) {
+                                          c10::optional<int64_t> start,
+                                          c10::optional<int64_t> end,
+                                          int64_t step) {
   TORCH_LAZY_FN_COUNTER("xla::");
   int64_t start_val = start.has_value() ? start.value() : 0;
   int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
@@ -2606,12 +2608,12 @@ at::Tensor XLANativeFunctions::slice_scatter(
     c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
   auto base_ = bridge::GetXlaTensor(base);
   auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
-  auto base_clone = XLATensor::clone(base_);
+  auto base_clone = tensor_methods::clone(base_);
   int64_t start_val = start.has_value() ? start.value() : 0;
   int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
   auto base_clone_slice =
-      XLATensor::slice(base_clone, dim, start_val, end_val, step);
-  XLATensor::copy_(base_clone_slice, mutated_view_);
+      tensor_methods::slice(base_clone, dim, start_val, end_val, step);
+  tensor_methods::copy_(base_clone_slice, mutated_view_);
   return bridge::AtenFromXlaTensor(base_clone);
 }
 
@@ -2683,8 +2685,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::sort(
 }
 
 std::vector<at::Tensor> XLANativeFunctions::split_copy(const at::Tensor& self,
-                                                  int64_t split_size,
-                                                  int64_t dim) {
+                                                       int64_t split_size,
+                                                       int64_t dim) {
   TORCH_LAZY_FN_COUNTER("xla::");
   auto xla_tensors =
       tensor_methods::split(bridge::GetXlaTensor(self), split_size, dim);
@@ -2711,7 +2713,8 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) {
       tensor_methods::squeeze(bridge::GetXlaTensor(self)));
 }
 
-at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, int64_t dim) {
+at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self,
+                                            int64_t dim) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::squeeze(bridge::GetXlaTensor(self), dim));
@@ -2882,8 +2885,8 @@ at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
       tensor_methods::trace(bridge::GetXlaTensor(self)));
 }
 
-at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, int64_t dim0,
-                                         int64_t dim1) {
+at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self,
+                                              int64_t dim0, int64_t dim1) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::transpose(bridge::GetXlaTensor(self), dim0, dim1));
@@ -2903,7 +2906,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::triangular_solve(
 }
 
 std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
-                                                   int64_t dim) {
+                                                        int64_t dim) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensors(
       tensor_methods::unbind(bridge::GetXlaTensor(self), dim));
@@ -2923,7 +2926,8 @@ at::Tensor& XLANativeFunctions::uniform_(
   return self;
 }
 
-at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) {
+at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self,
+                                              int64_t dim) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim));
@@ -3139,9 +3143,9 @@ at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self,
       pixel_unshuffle)>::call(self, downscale_factor);
 }
 
-at::Tensor XLANativeFunctions::select_backward_symint(const at::Tensor& grad_output,
-                                               at::IntArrayRef input_sizes,
-                                               int64_t dim, int64_t index) {
+at::Tensor XLANativeFunctions::select_backward_symint(
+    const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim,
+    c10::SymInt index) {
   std::cout << "WONJOO: at XLANativeFunctions::select_backward" << std::endl;
   return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
       select_backward)>::call(grad_output, input_sizes, dim, index);
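
All of the *_scatter overrides updated above share one shape: clone the base, take the requested view of the clone, copy_ the mutated view into that slot, and return the clone, leaving the base untouched. A minimal sketch of that flow for a 1-D slice_scatter, with a hypothetical flat-buffer Tensor standing in for the real XLATensorPtr:

#include <cstdint>
#include <vector>

// Hypothetical stand-in tensor: a flat 1-D buffer, for illustration only.
struct Tensor {
  std::vector<int64_t> data;
};

// Mirrors the slice_scatter flow in the diff: clone the base, view the
// slice [start, end) with stride `step`, overwrite it from mutated_view,
// and return the clone (the base itself is never mutated).
Tensor slice_scatter(const Tensor& base, const Tensor& mutated_view,
                     int64_t start, int64_t end, int64_t step) {
  Tensor base_clone = base;  // stands in for tensor_methods::clone(base_)
  int64_t src = 0;
  for (int64_t i = start;
       i < end && src < static_cast<int64_t>(mutated_view.data.size());
       i += step) {
    base_clone.data[i] = mutated_view.data[src++];  // slice view + copy_
  }
  return base_clone;  // stands in for bridge::AtenFromXlaTensor(base_clone)
}

int main() {
  Tensor base{{0, 0, 0, 0, 0, 0}};
  Tensor view{{7, 8, 9}};
  Tensor out = slice_scatter(base, view, /*start=*/1, /*end=*/6, /*step=*/2);
  // out.data is now {0, 7, 0, 8, 0, 9}; base is unchanged.
}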
