Skip to content

Commit fa96188

Browse files
committed
Add more debugging lines
1 parent 04e01c3 commit fa96188

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
6161
} // namespace
6262

6363
XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
64+
std::cout << "WONJOO: at aten_xla_bridge.cpp, TryGetXlaTensor" << std::endl;
6465
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
6566
if (impl == nullptr) {
6667
return XLATensorPtr();
@@ -69,17 +70,20 @@ XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
6970
}
7071

7172
// Returns true iff `tensor` is backed by an XLATensorImpl (i.e. it is an
// XLA tensor), false for any other backend.
bool IsXlaTensor(const at::Tensor& tensor) {
  // Temporary debug trace added by this commit ("Add more debugging lines").
  std::cout << "WONJOO: at aten_xla_bridge.cpp, IsXlaTensor" << std::endl;
  XLATensorImpl* const impl = GetXlaTensorImpl(tensor);
  return impl != nullptr;
}
7476

7577
// Looks up the XLATensorPtr backing `tensor`. Unlike TryGetXlaTensor, this
// variant treats a non-XLA input as a hard error: XLA_CHECK aborts with the
// tensor's description when the lookup fails.
XLATensorPtr GetXlaTensor(const at::Tensor& tensor) {
  // Temporary debug trace added by this commit ("Add more debugging lines").
  std::cout << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensor" << std::endl;
  XLATensorPtr xtensor = TryGetXlaTensor(tensor);
  XLA_CHECK(xtensor) << "Input tensor is not an XLA tensor: "
                     << tensor.toString();
  return xtensor;
}
8184

8285
void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
86+
std::cout << "WONJOO: at aten_xla_bridge.cpp, ReplaceXlaTensor" << std::endl;
8387
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
8488
XLATensorImpl* impl =
8589
dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
@@ -89,6 +93,7 @@ void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
8993
}
9094

9195
std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
96+
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensors" << std::endl;
9297
std::vector<XLATensorPtr> xla_tensors;
9398
xla_tensors.reserve(tensors.size());
9499
for (const auto& tensor : tensors) {
@@ -99,6 +104,7 @@ std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
99104

100105
torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
101106
const at::Tensor& tensor, const torch::lazy::BackendDevice& device) {
107+
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber" << std::endl;
102108
if (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
103109
(tensor.dim() == 0 && tensor.numel() == 1)) {
104110
return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
@@ -109,6 +115,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
109115

110116
XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
111117
const torch::lazy::BackendDevice& device) {
118+
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
112119
if (!tensor.defined()) {
113120
return XLATensorPtr();
114121
}
@@ -122,6 +129,7 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
122129

123130
XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
124131
const torch::lazy::BackendDevice& device) {
132+
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
125133
if (!IsDefined(tensor)) {
126134
return XLATensorPtr();
127135
}
@@ -133,6 +141,7 @@ XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
133141
std::vector<XLATensorPtr> GetOrCreateXlaTensors(
134142
absl::Span<const at::Tensor> tensors,
135143
const torch::lazy::BackendDevice& device) {
144+
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors" << std::endl;
136145
std::vector<XLATensorPtr> xla_tensors;
137146
for (const at::Tensor& tensor : tensors) {
138147
xla_tensors.push_back(bridge::GetOrCreateXlaTensor(tensor, device));
@@ -141,6 +150,7 @@ std::vector<XLATensorPtr> GetOrCreateXlaTensors(
141150
}
142151

143152
std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
153+
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList" << std::endl;
144154
std::vector<at::Tensor> aten_xla_tensors(tensors.size());
145155
std::vector<XLATensorPtr> xla_tensors;
146156
// We need to separate out the defined tensors first, GetXlaTensor() doesn't
@@ -179,6 +189,7 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
179189

180190
std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList(
181191
const std::vector<c10::optional<at::Tensor>>& tensors) {
192+
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList" << std::endl;
182193
std::vector<c10::optional<at::Tensor>> opt_aten_xla_tensors(tensors.size());
183194
std::vector<at::Tensor> materialized_tensors;
184195
std::vector<bool> to_translate(tensors.size());
@@ -202,6 +213,7 @@ std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList(
202213
void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,
203214
absl::Span<const at::Tensor> source_cpu_tensors,
204215
absl::Span<const size_t> indices) {
216+
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaUpdateTensors" << std::endl;
205217
for (auto index : indices) {
206218
at::Tensor dest = dest_xla_tensors.at(index);
207219
at::Tensor source = source_cpu_tensors.at(index);
@@ -332,6 +344,7 @@ c10::Device GetCurrentAtenDevice() {
332344

333345
at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
334346
const at::TensorOptions& tensor_options) {
347+
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaToAtenTensor" << std::endl;
335348
if (tensor_options.has_device()) {
336349
XLA_CHECK_NE(tensor_options.device().type(), at::kXLA);
337350
}
@@ -343,6 +356,7 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
343356
}
344357

345358
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
359+
std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensor" << std::endl;
346360
if (xla_tensor) {
347361
auto out =
348362
at::Tensor(c10::make_intrusive<XLATensorImpl>(std::move(xla_tensor)));
@@ -364,6 +378,7 @@ at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
364378

365379
std::vector<at::Tensor> AtenFromXlaTensors(
366380
absl::Span<const XLATensorPtr> xla_tensors) {
381+
std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors" << std::endl;
367382
std::vector<at::Tensor> tensors;
368383
tensors.reserve(xla_tensors.size());
369384
for (auto& tensor : xla_tensors) {
@@ -375,6 +390,7 @@ std::vector<at::Tensor> AtenFromXlaTensors(
375390
at::Tensor CreateXlaTensor(
376391
at::Tensor tensor,
377392
const c10::optional<torch::lazy::BackendDevice>& device) {
393+
std::cout << "WONJOO: at aten_xla_bridge.cpp, CreateXlaTensor" << std::endl;
378394
if (tensor.defined() && device) {
379395
XLATensorPtr xla_tensor = XLATensor::Create(std::move(tensor), *device);
380396
tensor = AtenFromXlaTensor(xla_tensor);
@@ -385,6 +401,7 @@ at::Tensor CreateXlaTensor(
385401
std::vector<at::Tensor> CreateXlaTensors(
386402
const std::vector<at::Tensor>& tensors,
387403
const c10::optional<torch::lazy::BackendDevice>& device) {
404+
std::cout << "WONJOO: at aten_xla_bridge.cpp, CreateXlaTensors" << std::endl;
388405
std::vector<at::Tensor> xtensors;
389406
for (auto& tensor : tensors) {
390407
xtensors.push_back(CreateXlaTensor(tensor, device));

0 commit comments

Comments (0)