Skip to content

Commit 0ab17c3

Browse files
author
Jorge Pineda
committed
[ET-VK] Save and load VkPipelineCache file if path is specified
Pull Request resolved: #3546 ## Context Pipeline creation involves the compilation of shader SPIR-V code into machine-specific code. This makes the application's first model-load via the `Program::load_method()` ET-API very slow, due to the creation of compute pipelines via the `vkCreateComputePipelines()` VK-API. To amortize that cost, Vulkan offers a [Compute Pipeline Cache API](https://docs.vulkan.org/guide/latest/pipeline_cache.html). Following [this Vulkan example](https://github.com/KhronosGroup/Vulkan-Samples/tree/main/samples/performance/pipeline_cache), we can (A) retrieve the compiled machine-specific code saving it to a file and (B) load it to a file next time. For an internal model executing on a resource-constrained device, this improves model-load time from ~1200ms to ~500ms. ## This change We implement both (A)+(B) ET-VK logic. Note that these changes are actually no-op unless you initialize the `pipeline_cache_file_path` manually. The expectation is for the client application to specify the file path of their pipeline cache data if they want to leverage this optimization. In a future ET-wide change, we will expose the file_path config parameter to the ET-API. ghstack-source-id: 225755565 Differential Revision: [D57085276](https://our.internmc.facebook.com/intern/diff/D57085276/)
1 parent 58922d1 commit 0ab17c3

6 files changed

Lines changed: 66 additions & 10 deletions

File tree

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ DeviceHandle::~DeviceHandle() {
292292
Adapter::Adapter(
293293
VkInstance instance,
294294
PhysicalDevice physical_device,
295-
const uint32_t num_queues)
295+
const uint32_t num_queues,
296+
const std::string& pipeline_cache_file_path)
296297
: queue_usage_mutex_{},
297298
physical_device_(std::move(physical_device)),
298299
queues_{},
@@ -307,7 +308,7 @@ Adapter::Adapter(
307308
shader_layout_cache_(device_.handle_),
308309
shader_cache_(device_.handle_),
309310
pipeline_layout_cache_(device_.handle_),
310-
compute_pipeline_cache_(device_.handle_),
311+
compute_pipeline_cache_(device_.handle_, pipeline_cache_file_path),
311312
sampler_cache_(device_.handle_),
312313
vma_(instance_, physical_device_.handle, device_.handle_) {}
313314

backends/vulkan/runtime/api/Adapter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ class Adapter final {
101101
explicit Adapter(
102102
VkInstance instance,
103103
PhysicalDevice physical_device,
104-
const uint32_t num_queues);
104+
const uint32_t num_queues,
105+
const std::string& pipeline_cache_file_path);
105106

106107
Adapter(const Adapter&) = delete;
107108
Adapter& operator=(const Adapter&) = delete;

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/api/Pipeline.h>
1010

11+
#include <fstream>
12+
1113
namespace vkcompute {
1214
namespace api {
1315

@@ -358,17 +360,24 @@ void PipelineLayoutCache::purge() {
358360
// ComputePipelineCache
359361
//
360362

361-
ComputePipelineCache::ComputePipelineCache(VkDevice device)
363+
ComputePipelineCache::ComputePipelineCache(
364+
VkDevice device,
365+
const std::string& file_path)
362366
: cache_mutex_{},
363367
device_(device),
364368
pipeline_cache_{VK_NULL_HANDLE},
365-
cache_{} {
366-
const VkPipelineCacheCreateInfo pipeline_cache_create_info{
369+
cache_{},
370+
file_path_(file_path) {
371+
VkPipelineCacheCreateInfo pipeline_cache_create_info{};
372+
373+
auto buffer = load_cache();
374+
375+
pipeline_cache_create_info = {
367376
VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
368377
nullptr, // pNext
369378
0u, // flags
370-
0u, // initialDataSize
371-
nullptr, // pInitialData
379+
buffer.size(), // initialDataSize
380+
buffer.data(), // pInitialData
372381
};
373382

374383
VK_CHECK(vkCreatePipelineCache(
@@ -392,6 +401,9 @@ ComputePipelineCache::~ComputePipelineCache() {
392401
if (VK_NULL_HANDLE == pipeline_cache_) {
393402
return;
394403
}
404+
405+
save_cache();
406+
395407
vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
396408
pipeline_cache_ = VK_NULL_HANDLE;
397409
}
@@ -416,5 +428,37 @@ void ComputePipelineCache::purge() {
416428
cache_.clear();
417429
}
418430

431+
std::vector<char> ComputePipelineCache::load_cache() {
432+
// Return if path is not specified; this means the optimization is disabled
433+
if (file_path_.empty()) {
434+
return {};
435+
}
436+
437+
// Return if file doesn't exist; this is expected on the first model-load
438+
std::ifstream file(file_path_, std::ios::binary | std::ios::ate);
439+
if (file.fail()) {
440+
return {};
441+
}
442+
443+
auto size = file.tellg();
444+
file.seekg(0, std::ios::beg);
445+
446+
std::vector<char> buffer(size);
447+
file.read(buffer.data(), size);
448+
449+
return buffer;
450+
}
451+
452+
void ComputePipelineCache::save_cache() {
453+
size_t size{};
454+
vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr);
455+
456+
std::vector<char> buffer(size);
457+
vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data());
458+
459+
std::ofstream file(file_path_, std::ios::binary);
460+
file.write(buffer.data(), buffer.size());
461+
}
462+
419463
} // namespace api
420464
} // namespace vkcompute

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class PipelineLayoutCache final {
216216

217217
class ComputePipelineCache final {
218218
public:
219-
explicit ComputePipelineCache(VkDevice device);
219+
explicit ComputePipelineCache(VkDevice device, const std::string& file_path);
220220

221221
ComputePipelineCache(const ComputePipelineCache&) = delete;
222222
ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;
@@ -266,13 +266,17 @@ class ComputePipelineCache final {
266266
};
267267

268268
private:
269+
std::vector<char> load_cache();
270+
void save_cache();
271+
269272
// Multiple threads could potentially be adding entries into the cache, so use
270273
// a mutex to manage access
271274
std::mutex cache_mutex_;
272275

273276
VkDevice device_;
274277
VkPipelineCache pipeline_cache_;
275278
std::unordered_map<Key, Value, Hasher> cache_;
279+
const std::string file_path_;
276280

277281
public:
278282
VkPipeline retrieve(const Key&);

backends/vulkan/runtime/api/Runtime.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,14 @@ std::unique_ptr<Runtime> init_global_vulkan_runtime() {
253253
#endif /* VULKAN_DEBUG */
254254
const bool init_default_device = true;
255255
const uint32_t num_requested_queues = 1; // TODO: raise this value
256+
const std::string pipeline_cache_file_path = ""; // TODO: expose to client
256257

257258
const RuntimeConfiguration default_config{
258259
enable_validation_messages,
259260
init_default_device,
260261
AdapterSelector::First,
261262
num_requested_queues,
263+
pipeline_cache_file_path,
262264
};
263265

264266
try {
@@ -351,7 +353,10 @@ uint32_t Runtime::create_adapter(const Selector& selector) {
351353
// Otherwise, create an adapter for the selected physical device
352354
adapter_i = utils::safe_downcast<int32_t>(adapters_.size());
353355
adapters_.emplace_back(new Adapter(
354-
instance_, device_mapping.first, config_.num_requested_queues));
356+
instance_,
357+
device_mapping.first,
358+
config_.num_requested_queues,
359+
config_.pipeline_cache_file_path));
355360
device_mapping.second = adapter_i;
356361

357362
return adapter_i;

backends/vulkan/runtime/api/Runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct RuntimeConfiguration final {
3939
bool init_default_device;
4040
AdapterSelector default_selector;
4141
uint32_t num_requested_queues;
42+
std::string pipeline_cache_file_path;
4243
};
4344

4445
class Runtime final {

0 commit comments

Comments
 (0)