diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp index 1d097cfd989..ffc45ecc960 100644 --- a/extension/data_loader/file_data_loader.cpp +++ b/extension/data_loader/file_data_loader.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include @@ -17,7 +18,6 @@ #include #include #include -#include #include #include @@ -69,8 +69,10 @@ FileDataLoader::~FileDataLoader() { // file_name_ can be nullptr if this instance was moved from, but freeing a // null pointer is safe. std::free(const_cast(file_name_)); - // fd_ can be -1 if this instance was moved from, but closing a negative fd is - // safe (though it will return an error). + // fd_ can be -1 if this instance was moved from. + if (fd_ == -1) { + return; + } ::close(fd_); } diff --git a/extension/data_loader/pread.h b/extension/data_loader/pread.h new file mode 100644 index 00000000000..765397f3e85 --- /dev/null +++ b/extension/data_loader/pread.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file ensures that a pread-compatible function is defined in the global namespace for windows and posix environments. + +#pragma once + +#ifndef _WIN32 +#include +#else +#include // For ssize_t. +#include + +#include +// To avoid conflicts with std::numeric_limits::max() in +// file_data_loader.cpp. +#undef max + +inline ssize_t pread(int fd, void* buf, size_t nbytes, size_t offset) { + OVERLAPPED overlapped; /* The offset for ReadFile. */ + memset(&overlapped, 0, sizeof(overlapped)); + overlapped.Offset = offset; + overlapped.OffsetHigh = offset >> 32; + + BOOL result; /* The result of ReadFile. */ + DWORD bytes_read; /* The number of bytes read. */ + HANDLE file = (HANDLE)_get_osfhandle(fd); + + result = ReadFile(file, buf, nbytes, &bytes_read, &overlapped); + DWORD error = GetLastError(); + if (!result) { + if (error == ERROR_IO_PENDING) { + result = GetOverlappedResult(file, &overlapped, &bytes_read, TRUE); + if (!result) { + error = GetLastError(); + } + } + } + if (!result) { + // Translate error into errno. + switch (error) { + case ERROR_HANDLE_EOF: + errno = 0; + break; + default: + errno = EIO; + break; + } + return -1; + } + return bytes_read; +} + +#endif diff --git a/extension/data_loader/targets.bzl b/extension/data_loader/targets.bzl index 4886df03a76..ceb92d7ec7f 100644 --- a/extension/data_loader/targets.bzl +++ b/extension/data_loader/targets.bzl @@ -40,6 +40,7 @@ def define_common_targets(): runtime.cxx_library( name = "file_data_loader", srcs = ["file_data_loader.cpp"], + headers = ["pread.h"], exported_headers = ["file_data_loader.h"], visibility = [ "//executorch/test/...",