@@ -61,6 +61,7 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
6161} // namespace
6262
6363XLATensorPtr TryGetXlaTensor (const at::Tensor& tensor) {
64+ std::cout << " WONJOO: at aten_xla_bridge.cpp, TryGetXlaTensor" << std::endl;
6465 XLATensorImpl* impl = GetXlaTensorImpl (tensor);
6566 if (impl == nullptr ) {
6667 return XLATensorPtr ();
@@ -69,17 +70,20 @@ XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
6970}
7071
7172bool IsXlaTensor (const at::Tensor& tensor) {
73+ std::cout << " WONJOO: at aten_xla_bridge.cpp, IsXlaTensor" << std::endl;
7274 return GetXlaTensorImpl (tensor) != nullptr ;
7375}
7476
7577XLATensorPtr GetXlaTensor (const at::Tensor& tensor) {
78+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetXlaTensor" << std::endl;
7679 auto xtensor = TryGetXlaTensor (tensor);
7780 XLA_CHECK (xtensor) << " Input tensor is not an XLA tensor: "
7881 << tensor.toString ();
7982 return xtensor;
8083}
8184
8285void ReplaceXlaTensor (const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
86+ std::cout << " WONJOO: at aten_xla_bridge.cpp, ReplaceXlaTensor" << std::endl;
8387 auto inner_tensor = torch::lazy::maybe_unwrap_functional (tensor);
8488 XLATensorImpl* impl =
8589 dynamic_cast <XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl ());
@@ -89,6 +93,7 @@ void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
8993}
9094
9195std::vector<XLATensorPtr> GetXlaTensors (const at::ITensorListRef& tensors) {
96+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetXlaTensors" << std::endl;
9297 std::vector<XLATensorPtr> xla_tensors;
9398 xla_tensors.reserve (tensors.size ());
9499 for (const auto & tensor : tensors) {
@@ -99,6 +104,7 @@ std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
99104
100105torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber (
101106 const at::Tensor& tensor, const torch::lazy::BackendDevice& device) {
107+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber" << std::endl;
102108 if (tensor.unsafeGetTensorImpl ()->is_wrapped_number () ||
103109 (tensor.dim () == 0 && tensor.numel () == 1 )) {
104110 return torch_xla::bridge::GetOrCreateXlaTensor (tensor, device);
@@ -109,6 +115,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
109115
110116XLATensorPtr GetOrCreateXlaTensor (const at::Tensor& tensor,
111117 const torch::lazy::BackendDevice& device) {
118+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
112119 if (!tensor.defined ()) {
113120 return XLATensorPtr ();
114121 }
@@ -122,6 +129,7 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
122129
123130XLATensorPtr GetOrCreateXlaTensor (const c10::optional<at::Tensor>& tensor,
124131 const torch::lazy::BackendDevice& device) {
132+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
125133 if (!IsDefined (tensor)) {
126134 return XLATensorPtr ();
127135 }
@@ -133,6 +141,7 @@ XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
133141std::vector<XLATensorPtr> GetOrCreateXlaTensors (
134142 absl::Span<const at::Tensor> tensors,
135143 const torch::lazy::BackendDevice& device) {
144+ std::cout << " WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors" << std::endl;
136145 std::vector<XLATensorPtr> xla_tensors;
137146 for (const at::Tensor& tensor : tensors) {
138147 xla_tensors.push_back (bridge::GetOrCreateXlaTensor (tensor, device));
@@ -141,6 +150,7 @@ std::vector<XLATensorPtr> GetOrCreateXlaTensors(
141150}
142151
143152std::vector<at::Tensor> XlaCreateTensorList (const at::ITensorListRef& tensors) {
153+ std::cout << " WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList" << std::endl;
144154 std::vector<at::Tensor> aten_xla_tensors (tensors.size ());
145155 std::vector<XLATensorPtr> xla_tensors;
146156 // We need to separate out the defined tensors first, GetXlaTensor() doesn't
@@ -179,6 +189,7 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
179189
180190std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList (
181191 const std::vector<c10::optional<at::Tensor>>& tensors) {
192+ std::cout << " WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList" << std::endl;
182193 std::vector<c10::optional<at::Tensor>> opt_aten_xla_tensors (tensors.size ());
183194 std::vector<at::Tensor> materialized_tensors;
184195 std::vector<bool > to_translate (tensors.size ());
@@ -202,6 +213,7 @@ std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList(
202213void XlaUpdateTensors (absl::Span<const at::Tensor> dest_xla_tensors,
203214 absl::Span<const at::Tensor> source_cpu_tensors,
204215 absl::Span<const size_t > indices) {
216+ std::cout << " WONJOO: at aten_xla_bridge.cpp, XlaUpdateTensors" << std::endl;
205217 for (auto index : indices) {
206218 at::Tensor dest = dest_xla_tensors.at (index);
207219 at::Tensor source = source_cpu_tensors.at (index);
@@ -332,6 +344,7 @@ c10::Device GetCurrentAtenDevice() {
332344
333345at::Tensor XlaToAtenTensor (XLATensorPtr xla_tensor,
334346 const at::TensorOptions& tensor_options) {
347+ std::cout << " WONJOO: at aten_xla_bridge.cpp, XlaToAtenTensor" << std::endl;
335348 if (tensor_options.has_device ()) {
336349 XLA_CHECK_NE (tensor_options.device ().type (), at::kXLA );
337350 }
@@ -343,6 +356,7 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
343356}
344357
345358at::Tensor AtenFromXlaTensor (XLATensorPtr xla_tensor) {
359+ std::cout << " WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensor" << std::endl;
346360 if (xla_tensor) {
347361 auto out =
348362 at::Tensor (c10::make_intrusive<XLATensorImpl>(std::move (xla_tensor)));
@@ -364,6 +378,7 @@ at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
364378
365379std::vector<at::Tensor> AtenFromXlaTensors (
366380 absl::Span<const XLATensorPtr> xla_tensors) {
381+ std::cout << " WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors" << std::endl;
367382 std::vector<at::Tensor> tensors;
368383 tensors.reserve (xla_tensors.size ());
369384 for (auto & tensor : xla_tensors) {
@@ -375,6 +390,7 @@ std::vector<at::Tensor> AtenFromXlaTensors(
375390at::Tensor CreateXlaTensor (
376391 at::Tensor tensor,
377392 const c10::optional<torch::lazy::BackendDevice>& device) {
393+ std::cout << " WONJOO: at aten_xla_bridge.cpp, CreateXlaTensor" << std::endl;
378394 if (tensor.defined () && device) {
379395 XLATensorPtr xla_tensor = XLATensor::Create (std::move (tensor), *device);
380396 tensor = AtenFromXlaTensor (xla_tensor);
@@ -385,6 +401,7 @@ at::Tensor CreateXlaTensor(
385401std::vector<at::Tensor> CreateXlaTensors (
386402 const std::vector<at::Tensor>& tensors,
387403 const c10::optional<torch::lazy::BackendDevice>& device) {
404+ std::cout << " WONJOO: at aten_xla_bridge.cpp, CreateXlaTensors" << std::endl;
388405 std::vector<at::Tensor> xtensors;
389406 for (auto & tensor : tensors) {
390407 xtensors.push_back (CreateXlaTensor (tensor, device));
0 commit comments