diff --git a/include/openPMD/Iteration.hpp b/include/openPMD/Iteration.hpp index 3c7fffc545..fd1f14cd00 100644 --- a/include/openPMD/Iteration.hpp +++ b/include/openPMD/Iteration.hpp @@ -420,4 +420,24 @@ inline T Iteration::dt() const { return this->readFloatingpoint("dt"); } + +/** + * @brief Subclass of Iteration that knows its own index withing the containing + * Series. + */ +class IndexedIteration : public Iteration +{ + friend class SeriesIterator; + friend class WriteIterations; + +public: + using index_t = Iteration::IterationIndex_t; + index_t const iterationIndex; + +private: + template + IndexedIteration(Iteration_t &&it, index_t index) + : Iteration(std::forward(it)), iterationIndex(index) + {} +}; } // namespace openPMD diff --git a/include/openPMD/ReadIterations.hpp b/include/openPMD/ReadIterations.hpp index c6a1e4fc36..c381cdd62b 100644 --- a/include/openPMD/ReadIterations.hpp +++ b/include/openPMD/ReadIterations.hpp @@ -31,26 +31,6 @@ namespace openPMD { -/** - * @brief Subclass of Iteration that knows its own index withing the containing - * Series. - */ -class IndexedIteration : public Iteration -{ - friend class SeriesIterator; - -public: - using iterations_t = decltype(internal::SeriesData::iterations); - using index_t = iterations_t::key_type; - index_t const iterationIndex; - -private: - template - IndexedIteration(Iteration_t &&it, index_t index) - : Iteration(std::forward(it)), iterationIndex(index) - {} -}; - class SeriesIterator { using iteration_index_t = IndexedIteration::index_t; diff --git a/include/openPMD/WriteIterations.hpp b/include/openPMD/WriteIterations.hpp index 3099af7025..7c457e7cfe 100644 --- a/include/openPMD/WriteIterations.hpp +++ b/include/openPMD/WriteIterations.hpp @@ -87,5 +87,10 @@ class WriteIterations public: mapped_type &operator[](key_type const &key); mapped_type &operator[](key_type &&key); + + /** + * Return the iteration that is currently being written to, if it exists. + */ + std::optional currentIteration(); }; } // namespace openPMD diff --git a/src/WriteIterations.cpp b/src/WriteIterations.cpp index 2bc34f0416..f5e976a6f4 100644 --- a/src/WriteIterations.cpp +++ b/src/WriteIterations.cpp @@ -69,13 +69,17 @@ WriteIterations::mapped_type &WriteIterations::operator[](key_type &&key) "[WriteIterations] Trying to access after closing Series."); } auto &s = shared->value(); - if (s.currentlyOpen.has_value()) + auto lastIteration = currentIteration(); + if (lastIteration.has_value()) { - auto lastIterationIndex = s.currentlyOpen.value(); - auto &lastIteration = s.iterations.at(lastIterationIndex); - if (lastIterationIndex != key && !lastIteration.closed()) + auto lastIteration_v = lastIteration.value(); + if (lastIteration_v.iterationIndex == key) { - lastIteration.close(); + return s.iterations.at(std::move(key)); + } + else + { + lastIteration_v.close(); // continue below } } s.currentlyOpen = key; @@ -87,4 +91,24 @@ WriteIterations::mapped_type &WriteIterations::operator[](key_type &&key) } return res; } + +std::optional WriteIterations::currentIteration() +{ + if (!shared || !shared->has_value()) + { + return std::nullopt; + } + auto &s = shared->value(); + if (!s.currentlyOpen.has_value()) + { + return std::nullopt; + } + Iteration ¤tIteration = s.iterations.at(s.currentlyOpen.value()); + if (currentIteration.closed()) + { + return std::nullopt; + } + return std::make_optional( + IndexedIteration(currentIteration, s.currentlyOpen.value())); +} } // namespace openPMD diff --git a/src/binding/python/Iteration.cpp b/src/binding/python/Iteration.cpp index 0ac290f7ff..59a9322039 100644 --- a/src/binding/python/Iteration.cpp +++ b/src/binding/python/Iteration.cpp @@ -63,8 +63,20 @@ void init_Iteration(py::module &m) "dt", &Iteration::dt, &Iteration::setDt) .def_property( "time_unit_SI", &Iteration::timeUnitSI, &Iteration::setTimeUnitSI) - .def("open", &Iteration::open) - .def("close", &Iteration::close, py::arg("flush") = true) + .def( + "open", + [](Iteration &it) { + py::gil_scoped_release release; + return it.open(); + }) + .def( + "close", + /* + * Cannot release the GIL here since Python buffers might be + * accessed in deferred tasks + */ + &Iteration::close, + py::arg("flush") = true) // TODO remove in future versions (deprecated) .def("set_time", &Iteration::setTime) diff --git a/src/binding/python/Series.cpp b/src/binding/python/Series.cpp index cdff83fd43..8874c21e43 100644 --- a/src/binding/python/Series.cpp +++ b/src/binding/python/Series.cpp @@ -53,24 +53,90 @@ struct openPMD_PyMPICommObject using openPMD_PyMPIIntracommObject = openPMD_PyMPICommObject; #endif +struct SeriesIteratorPythonAdaptor : SeriesIterator +{ + SeriesIteratorPythonAdaptor(SeriesIterator it) + : SeriesIterator(std::move(it)) + {} + + /* + * Python iterators are weird and call `__next__()` already for getting the + * first element. + * In that case, no `operator++()` must be called... + */ + bool first_iteration = true; +}; + void init_Series(py::module &m) { py::class_(m, "WriteIterations") .def( "__getitem__", [](WriteIterations writeIterations, Series::IterationIndex_t key) { + auto lastIteration = writeIterations.currentIteration(); + if (lastIteration.has_value() && + lastIteration.value().iterationIndex != key) + { + // this must happen under the GIL + lastIteration.value().close(); + } + py::gil_scoped_release release; return writeIterations[key]; }, // copy + keepalive - py::return_value_policy::copy); + py::return_value_policy::copy) + .def( + "current_iteration", + &WriteIterations::currentIteration, + "Return the iteration that is currently being written to, if it " + "exists."); py::class_(m, "IndexedIteration") .def_readonly("iteration_index", &IndexedIteration::iterationIndex); + + py::class_(m, "SeriesIterator") + .def( + "__next__", + [](SeriesIteratorPythonAdaptor &iterator) { + if (iterator == SeriesIterator::end()) + { + throw py::stop_iteration(); + } + /* + * Closing the iteration must happen under the GIL lock since + * Python buffers might be accessed + */ + if (!iterator.first_iteration) + { + if (!(*iterator).closed()) + { + (*iterator).close(); + } + py::gil_scoped_release release; + ++iterator; + } + iterator.first_iteration = false; + if (iterator == SeriesIterator::end()) + { + throw py::stop_iteration(); + } + else + { + return *iterator; + } + } + + ); + py::class_(m, "ReadIterations") .def( "__iter__", [](ReadIterations &readIterations) { - return py::make_iterator( - readIterations.begin(), readIterations.end()); + // Simple iterator implementation: + // But we need to release the GIL inside + // SeriesIterator::operator++, so manually it is + // return py::make_iterator( + // readIterations.begin(), readIterations.end()); + return SeriesIteratorPythonAdaptor(readIterations.begin()); }, // keep handle alive while iterator exists py::keep_alive<0, 1>()); @@ -78,7 +144,12 @@ void init_Series(py::module &m) py::class_(m, "Series") .def( - py::init(), + py::init([](std::string const &filepath, + Access at, + std::string const &options) { + py::gil_scoped_release release; + return new Series(filepath, at, options); + }), py::arg("filepath"), py::arg("access"), py::arg("options") = "{}") @@ -145,6 +216,7 @@ void init_Series(py::module &m) "(Mismatched MPI at compile vs. runtime?)"); } + py::gil_scoped_release release; return new Series(filepath, at, *mpiCommPtr, options); }), py::arg("filepath"), @@ -232,7 +304,13 @@ this method. py::return_value_policy::reference, // garbage collection: return value must be freed before Series py::keep_alive<1, 0>()) - .def("read_iterations", &Series::readIterations, py::keep_alive<0, 1>()) + .def( + "read_iterations", + [](Series &s) { + py::gil_scoped_release release; + return s.readIterations(); + }, + py::keep_alive<0, 1>()) .def( "write_iterations", &Series::writeIterations,