diff --git a/sdk_v2/cpp/CMakeLists.txt b/sdk_v2/cpp/CMakeLists.txt index 671c7648..dcd191b4 100644 --- a/sdk_v2/cpp/CMakeLists.txt +++ b/sdk_v2/cpp/CMakeLists.txt @@ -148,8 +148,11 @@ set(FOUNDRY_LOCAL_SOURCES src/inferencing/generative/chat/chat_session.cc src/inferencing/generative/chat/chat_template.cc src/configuration.cc + src/download/blob_download_state.cc src/download/blob_downloader.cc + src/download/cross_process_file_lock.cc src/download/download_manager.cc + src/download/file_writer.cc src/download/inference_model_writer.cc src/download/model_registry_client.cc src/ep_detection/cuda_ep_bootstrapper.cc diff --git a/sdk_v2/cpp/src/download/blob_download_state.cc b/sdk_v2/cpp/src/download/blob_download_state.cc new file mode 100644 index 00000000..329109d7 --- /dev/null +++ b/sdk_v2/cpp/src/download/blob_download_state.cc @@ -0,0 +1,367 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "download/blob_download_state.h" +#include "logger.h" + +#include +#include +#include +#include +#include + +namespace fl { + +namespace { + +constexpr const char* kStateFileExtension = ".dlstate"; + +// On-disk format. Scalar fields use host byte order (little-endian on every +// target we build for); see WriteNative/ReadNative below. The bitmap suffix is +// a raw byte copy of the uint64 words, so it relies on that same byte order. +// bytes | field +// -------|-------------------------------------------------------- +// 0..3 | magic "FLDS" +// 4 | version (currently 1) +// 5..12 | blob_size (int64) +// 13..16 | chunk_size (int32) +// 17..20 | total_chunks (int32) +// 21..24 | bitmap_byte_aligned_start (int32) +// 25..28 | highest_completed_chunk (int32) +// 29..32 | completed_count (int32) +// 33..40 | last_modified_unix_ms (int64) +// 41..44 | trunc_bitmap_byte_len (uint32) +// 45.. | trunc_bitmap_byte_len bytes of bitmap data, copied directly out of +// full_completion_bitmap starting at the byte offset implied by +// bitmap_byte_aligned_start. 
+constexpr char kMagic[4] = {'F', 'L', 'D', 'S'}; +constexpr uint8_t kVersion = 1; + +constexpr int32_t kBitsPerWord = 64; + +// Serialize a scalar field in host byte order. Every target we build for +// (x64 / arm64) is little-endian, so the on-disk layout is little-endian in +// practice. +template +void WriteNative(std::ostream& out, T value) { + static_assert(std::is_trivially_copyable_v); + out.write(reinterpret_cast(&value), sizeof(T)); +} + +template +bool ReadNative(std::istream& in, T& out_value) { + static_assert(std::is_trivially_copyable_v); + in.read(reinterpret_cast(&out_value), sizeof(T)); + return static_cast(in); +} + +int64_t NowUnixMs() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +} // namespace + +std::filesystem::path BlobDownloadState::GetStateFilePath(const std::filesystem::path& local_file_path) { + auto p = local_file_path; + p += kStateFileExtension; + return p; +} + +std::unique_ptr BlobDownloadState::CreateNew(std::string blob_name, + std::filesystem::path local_file_path, + int64_t blob_size, + int32_t chunk_size, + int32_t total_chunks) { + auto state = std::make_unique(); + state->blob_name = std::move(blob_name); + state->local_file_path = local_file_path.string(); + state->blob_size = blob_size; + state->chunk_size = chunk_size; + state->total_chunks = total_chunks; + state->bitmap_byte_aligned_start = 0; + state->highest_completed_chunk = -1; + state->completed_count = 0; + state->last_modified_unix_ms = NowUnixMs(); + auto words = static_cast((total_chunks + kBitsPerWord - 1) / kBitsPerWord); + state->full_completion_bitmap.assign(words, 0); + return state; +} + +std::unique_ptr BlobDownloadState::LoadState(std::string blob_name, + std::filesystem::path local_file_path, + int64_t expected_blob_size, + int32_t expected_chunk_size, + int32_t expected_total_chunks, + ILogger& logger) { + auto state_path = GetStateFilePath(local_file_path); + std::error_code ec; + if 
(!std::filesystem::exists(state_path, ec)) { + return nullptr; + } + + std::ifstream in(state_path, std::ios::binary); + if (!in) { + logger.Log(LogLevel::Warning, "Could not open download state file: " + state_path.string()); + return nullptr; + } + + char magic[4]{}; + in.read(magic, 4); + uint8_t version = 0; + if (!in || std::memcmp(magic, kMagic, 4) != 0 || !ReadNative(in, version) || version != kVersion) { + logger.Log(LogLevel::Warning, + "Download state file " + state_path.string() + " has unexpected magic/version; ignoring"); + return nullptr; + } + + int64_t blob_size = 0; + int32_t chunk_size = 0; + int32_t total_chunks = 0; + int32_t bitmap_byte_aligned_start = 0; + int32_t highest_completed_chunk = 0; + int32_t completed_count = 0; + int64_t last_modified_unix_ms = 0; + uint32_t trunc_len = 0; + if (!ReadNative(in, blob_size) || !ReadNative(in, chunk_size) || !ReadNative(in, total_chunks) || + !ReadNative(in, bitmap_byte_aligned_start) || !ReadNative(in, highest_completed_chunk) || + !ReadNative(in, completed_count) || !ReadNative(in, last_modified_unix_ms) || !ReadNative(in, trunc_len)) { + logger.Log(LogLevel::Warning, "Download state header truncated: " + state_path.string()); + return nullptr; + } + + // Sanity / compatibility checks. 
+ if (blob_size != expected_blob_size || chunk_size != expected_chunk_size || + total_chunks != expected_total_chunks) { + logger.Log(LogLevel::Information, + "Download state for " + state_path.string() + + " is incompatible with current blob layout; starting fresh"); + return nullptr; + } + if (bitmap_byte_aligned_start < 0 || bitmap_byte_aligned_start % 8 != 0 || + bitmap_byte_aligned_start > total_chunks || completed_count < 0 || + completed_count > total_chunks || highest_completed_chunk < -1 || + highest_completed_chunk >= total_chunks) { + logger.Log(LogLevel::Warning, "Download state header values out of range: " + state_path.string()); + return nullptr; + } + + auto words_total = static_cast((total_chunks + kBitsPerWord - 1) / kBitsPerWord); + std::vector bitmap(words_total, 0); + + // The prefix of fully-completed chunks below bitmap_byte_aligned_start is + // implied — fill those bits. + size_t implicit_full_words = static_cast(bitmap_byte_aligned_start) / kBitsPerWord; + for (size_t i = 0; i < implicit_full_words && i < bitmap.size(); ++i) { + bitmap[i] = ~uint64_t{0}; + } + // Any remaining "implicit" bits inside a partial word (between + // implicit_full_words*64 and bitmap_byte_aligned_start). + if (size_t partial_bits = static_cast(bitmap_byte_aligned_start) % kBitsPerWord; + partial_bits > 0 && implicit_full_words < bitmap.size()) { + bitmap[implicit_full_words] |= (uint64_t{1} << partial_bits) - 1; + } + + if (trunc_len > 0) { + // Copy serialized bytes directly into the bitmap starting at the byte + // position implied by bitmap_byte_aligned_start. 
+ size_t byte_offset = static_cast(bitmap_byte_aligned_start) / 8; + auto* dest = reinterpret_cast(bitmap.data()) + byte_offset; + auto dest_capacity = bitmap.size() * sizeof(uint64_t) - byte_offset; + if (trunc_len > dest_capacity) { + logger.Log(LogLevel::Warning, + "Download state bitmap length exceeds expected capacity: " + state_path.string()); + return nullptr; + } + in.read(reinterpret_cast(dest), trunc_len); + if (!in) { + logger.Log(LogLevel::Warning, + "Download state bitmap payload truncated: " + state_path.string()); + return nullptr; + } + } + + auto state = std::make_unique(); + state->blob_name = std::move(blob_name); + state->local_file_path = local_file_path.string(); + state->blob_size = blob_size; + state->chunk_size = chunk_size; + state->total_chunks = total_chunks; + state->bitmap_byte_aligned_start = bitmap_byte_aligned_start; + state->highest_completed_chunk = highest_completed_chunk; + state->completed_count = completed_count; + state->last_modified_unix_ms = last_modified_unix_ms; + state->full_completion_bitmap = std::move(bitmap); + + logger.Log(LogLevel::Information, + "Loaded download state " + state_path.string() + ": " + + std::to_string(completed_count) + "/" + std::to_string(total_chunks) + + " chunks already done"); + return state; +} + +int64_t BlobDownloadState::CalculateDownloadedSize() const noexcept { + int64_t bytes = static_cast(completed_count) * chunk_size; + // If the final chunk is partial and was completed, adjust the overcount. + if (highest_completed_chunk == total_chunks - 1 && chunk_size > 0) { + auto remainder = blob_size % chunk_size; + if (remainder != 0) { + bytes -= (chunk_size - remainder); + } + } + return bytes; +} + +bool BlobDownloadState::IsChunkComplete(int32_t chunk_idx) const noexcept { + if (chunk_idx < 0 || chunk_idx >= total_chunks) { + return false; + } + if (chunk_idx < bitmap_byte_aligned_start) { + // Below the truncation point — implicitly complete. 
+ return true; + } + auto word_idx = static_cast(chunk_idx) / kBitsPerWord; + auto bit_idx = static_cast(chunk_idx) % kBitsPerWord; + if (word_idx >= full_completion_bitmap.size()) { + return false; + } + return (full_completion_bitmap[word_idx] & (uint64_t{1} << bit_idx)) != 0; +} + +void BlobDownloadState::MarkChunkComplete(int32_t chunk_idx) { + if (chunk_idx < 0 || chunk_idx >= total_chunks) { + return; + } + if (IsChunkComplete(chunk_idx)) { + return; + } + if (chunk_idx > highest_completed_chunk) { + highest_completed_chunk = chunk_idx; + } + auto word_idx = static_cast(chunk_idx) / kBitsPerWord; + auto bit_idx = static_cast(chunk_idx) % kBitsPerWord; + full_completion_bitmap[word_idx] |= (uint64_t{1} << bit_idx); + ++completed_count; +} + +std::vector BlobDownloadState::GetPendingChunks() const { + std::vector pending; + pending.reserve(static_cast(total_chunks - completed_count)); + for (int32_t i = bitmap_byte_aligned_start; i < total_chunks; ++i) { + if (!IsChunkComplete(i)) { + pending.push_back(i); + } + } + return pending; +} + +bool BlobDownloadState::SaveState(ILogger& logger) { + // Advance bitmap_byte_aligned_start past any words that are now all 1s, so + // the next save serializes only the unfinished tail. + // Find the first word that is not fully complete. Every word below it is + // implicitly complete and need not be serialized again. + size_t word_idx = static_cast(bitmap_byte_aligned_start) / kBitsPerWord; + while (word_idx < full_completion_bitmap.size() && + full_completion_bitmap[word_idx] == ~uint64_t{0}) { + ++word_idx; + } + int32_t new_start; + if (word_idx < full_completion_bitmap.size()) { + // Within the first not-fully-set word, advance to the lowest 0 bit. 
Derive + // the absolute chunk index from the word base (word_idx * 64), NOT by + // accumulating 64 per word onto the (possibly unaligned) previous start — + // the latter overshoots by (bitmap_byte_aligned_start % 64) and would mark + // never-downloaded chunks complete on reload. Round down to a byte boundary + // so reload-then-resume re-reads on a clean alignment. + uint64_t inverted = ~full_completion_bitmap[word_idx]; + int trailing_zero = 0; + while (trailing_zero < kBitsPerWord && ((inverted >> trailing_zero) & 1) == 0) { + ++trailing_zero; + } + new_start = static_cast(word_idx) * kBitsPerWord + trailing_zero; + } else { + // Every word is fully complete. + new_start = total_chunks; + } + new_start = (new_start / 8) * 8; + if (new_start > total_chunks) { + new_start = (total_chunks / 8) * 8; + } + if (new_start > bitmap_byte_aligned_start) { + bitmap_byte_aligned_start = new_start; + } + + last_modified_unix_ms = NowUnixMs(); + + auto state_path = GetStateFilePath(local_file_path); + auto tmp_path = state_path; + tmp_path += ".tmp"; + + // Compute the serialized bitmap payload: bytes from bitmap_byte_aligned_start + // up to (highest_completed_chunk + 1), rounded up to the nearest byte. 
+ uint32_t trunc_len = 0; + if (highest_completed_chunk >= bitmap_byte_aligned_start) { + int32_t bit_count = highest_completed_chunk - bitmap_byte_aligned_start + 1; + trunc_len = static_cast((bit_count + 7) / 8); + } + size_t byte_offset = static_cast(bitmap_byte_aligned_start) / 8; + + { + std::ofstream out(tmp_path, std::ios::binary | std::ios::trunc); + if (!out) { + logger.Log(LogLevel::Error, "Failed to open download state tmp file: " + tmp_path.string()); + return false; + } + out.write(kMagic, 4); + WriteNative(out, kVersion); + WriteNative(out, blob_size); + WriteNative(out, chunk_size); + WriteNative(out, total_chunks); + WriteNative(out, bitmap_byte_aligned_start); + WriteNative(out, highest_completed_chunk); + WriteNative(out, completed_count); + WriteNative(out, last_modified_unix_ms); + WriteNative(out, trunc_len); + if (trunc_len > 0) { + auto* src = reinterpret_cast(full_completion_bitmap.data()) + byte_offset; + out.write(reinterpret_cast(src), trunc_len); + } + if (!out) { + logger.Log(LogLevel::Error, "Failed to write download state tmp file: " + tmp_path.string()); + return false; + } + } + + std::error_code ec; + std::filesystem::rename(tmp_path, state_path, ec); + if (ec) { + // std::filesystem::rename atomically replaces the destination on every + // platform we target (POSIX rename(2); Windows MoveFileExW with + // MOVEFILE_REPLACE_EXISTING). If it still fails, the cause is transient + // (e.g. a brief sharing violation on Windows or a flaky network FS) — + // do NOT delete state_path as a fallback; that loses the only intact + // copy of the resume bitmap. Instead, drop the tmp file and let the + // next SaveState call retry from the up-to-date in-memory state. 
+ std::error_code rm_ec; + std::filesystem::remove(tmp_path, rm_ec); + logger.Log(LogLevel::Error, + "Failed to commit download state file: " + tmp_path.string() + " -> " + + state_path.string() + " (" + ec.message() + + "); previous state retained, will retry on next save"); + return false; + } + return true; +} + +void BlobDownloadState::DeleteState(const std::filesystem::path& local_file_path, ILogger& logger) { + auto state_path = GetStateFilePath(local_file_path); + std::error_code ec; + std::filesystem::remove(state_path, ec); + if (ec) { + logger.Log(LogLevel::Warning, + "Failed to delete download state file: " + state_path.string() + " (" + + ec.message() + ")"); + } +} + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/blob_download_state.h b/sdk_v2/cpp/src/download/blob_download_state.h new file mode 100644 index 00000000..362fed77 --- /dev/null +++ b/sdk_v2/cpp/src/download/blob_download_state.h @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fl { + +class ILogger; + +/// Per-blob download progress, persisted next to the data file as `.dlstate`. +/// +/// Each chunk completion flips a bit in `full_completion_bitmap`. On resume, +/// `GetPendingChunks` enumerates only chunks whose bits are still 0. +/// +/// The serialized form stores only the bitmap suffix starting at +/// `bitmap_byte_aligned_start` to `highest_completed_chunk`. +/// This keeps the on-disk state proportional to the *unfinished* +/// range, not the total file size. +/// +/// On-disk layout is a small fixed-width little-endian binary header followed +/// by the truncated bitmap bytes. +class BlobDownloadState { + public: + /// Identity of the blob (populated by caller; not serialized). + std::string blob_name; + std::string local_file_path; + + /// Fixed at first save; serialized for resume integrity checks. 
+ int64_t blob_size = 0; + int32_t chunk_size = 0; + int32_t total_chunks = 0; + + /// Serialization marker (always a multiple of 8): chunks below this index are + /// complete and dropped from the sidecar's truncated bitmap. The in-memory + /// `full_completion_bitmap` still covers them. + int32_t bitmap_byte_aligned_start = 0; + + /// Highest chunk index completed so far. -1 if no chunks are done yet. + int32_t highest_completed_chunk = -1; + + /// Cached count for O(1) `IsComplete()`. + int32_t completed_count = 0; + + /// Unix epoch milliseconds; refreshed on every save. + int64_t last_modified_unix_ms = 0; + + /// One bit per chunk over the whole blob: chunk `i` lives in word `i / 64` at + /// bit `i % 64` (absolute indexing — the buffer always starts at chunk 0). + /// Sized for all `total_chunks` by `CreateNew`; `MarkChunkComplete` sets bits + /// without resizing. + std::vector full_completion_bitmap; + + /// Sidecar path for `local_file_path`. + static std::filesystem::path GetStateFilePath(const std::filesystem::path& local_file_path); + + /// Construct a fresh state for a new download. Bitmap sized for `total_chunks`. + static std::unique_ptr CreateNew(std::string blob_name, + std::filesystem::path local_file_path, + int64_t blob_size, + int32_t chunk_size, + int32_t total_chunks); + + /// Load existing state from `.dlstate`. Returns nullptr if + /// the file does not exist, is corrupted, or has incompatible + /// `blob_size` / `chunk_size` / `total_chunks` (caller-provided values are + /// authoritative — a mismatch means the blob has been reconfigured upstream + /// and the partial download is no longer valid). + /// `logger` receives diagnostics for corrupt/incompatible state files. Required: the + /// downloader always has a logger, so there is no optional/null case to handle. 
+ static std::unique_ptr LoadState(std::string blob_name, + std::filesystem::path local_file_path, + int64_t expected_blob_size, + int32_t expected_chunk_size, + int32_t expected_total_chunks, + ILogger& logger); + + /// All chunks downloaded. + bool IsComplete() const noexcept { return completed_count == total_chunks; } + + /// Sum of bytes already written. Accounts for the final chunk being smaller + /// than `chunk_size` when blob_size is not chunk-aligned. + int64_t CalculateDownloadedSize() const noexcept; + + /// Whether `chunk_idx` is already marked complete. + bool IsChunkComplete(int32_t chunk_idx) const noexcept; + + /// Mark `chunk_idx` complete. Caller must hold the mutex when called from + /// concurrent worker tasks (use `mutex()` for that). Idempotent. + void MarkChunkComplete(int32_t chunk_idx); + + /// Enumerate chunks in [0, total_chunks) that are not yet complete. + std::vector GetPendingChunks() const; + + /// Atomically write current state to `.dlstate`. Returns true + /// on success; on failure it logs and returns false rather than throwing. Most + /// callers treat a failed periodic save as best-effort (the next save retries, + /// and resume just replays a few chunks); the initial pre-allocation save + /// treats false as fatal, since the "pre-allocated <=> sidecar present" + /// invariant depends on it. `logger` is required. + bool SaveState(ILogger& logger); + + /// Remove the sidecar; called on successful completion. + static void DeleteState(const std::filesystem::path& local_file_path, + ILogger& logger); + + /// Mutex protecting concurrent `MarkChunkComplete` / `SaveState` calls from + /// the chunk worker pool. 
+ std::mutex& mutex() noexcept { return mutex_; } + + private: + mutable std::mutex mutex_; +}; + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/blob_downloader.cc b/sdk_v2/cpp/src/download/blob_downloader.cc index 1d5c2981..c5da2fbb 100644 --- a/sdk_v2/cpp/src/download/blob_downloader.cc +++ b/sdk_v2/cpp/src/download/blob_downloader.cc @@ -1,16 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "download/blob_downloader.h" +#include "download/blob_download_state.h" +#include "download/file_writer.h" #include "exception.h" +#include "logger.h" #include "util/path_safety.h" #include "util/string_utils.h" #include #include #include +#include #include #include #include +#include #include #include #include @@ -20,10 +25,34 @@ namespace fl { +namespace { + +/// Streaming buffer size used by the production chunk downloader. Matches the +/// 64 KB-ish granularity Stream.CopyTo uses in .NET, capping per-worker peak +/// memory at this many bytes regardless of chunk size. +constexpr size_t kStreamingBufferBytes = 64 * 1024; + +} // namespace + // ======================================================================== // AzureBlobDownloader — real Azure Storage SDK implementation // ======================================================================== +/// Per-blob shared state passed to the protected virtuals. Both members are +/// references to objects the orchestrator owns on the stack for the lifetime of +/// the download, so they are never null. `blob_client` is const because every +/// call routed through it (GetProperties / Download) is a const SDK operation. +/// `azure_ctx` is const here because the virtuals only *observe* cancellation +/// (IsCancelled, and handing the context to SDK reads); the orchestrator +/// initiates cancellation by calling Cancel() on the owning Context directly, +/// not through this view. 
+struct AzureBlobDownloader::ChunkContext { + const Azure::Storage::Blobs::BlobClient& blob_client; + const Azure::Core::Context& azure_ctx; +}; + +AzureBlobDownloader::AzureBlobDownloader(ILogger& logger) : logger_(logger) {} + std::vector AzureBlobDownloader::ListBlobs(const std::string& sas_uri) { try { auto container_client = Azure::Storage::Blobs::BlobContainerClient(sas_uri); @@ -45,6 +74,60 @@ std::vector AzureBlobDownloader::ListBlobs(const std::string& sas_ } } +int64_t AzureBlobDownloader::GetBlobSize(ChunkContext& ctx) { + auto props = ctx.blob_client.GetProperties({}, ctx.azure_ctx).Value; + return props.BlobSize; +} + +bool AzureBlobDownloader::IsCancellationRequested(ChunkContext& ctx) { + return ctx.azure_ctx.IsCancelled(); +} + +void AzureBlobDownloader::DownloadChunkStreaming( + ChunkContext& ctx, int64_t offset, int64_t size, std::vector& scratch, + const std::function& sink) { + Azure::Storage::Blobs::DownloadBlobOptions range_opts; + range_opts.Range = Azure::Core::Http::HttpRange{offset, size}; + auto result = ctx.blob_client.Download(range_opts, ctx.azure_ctx); + auto& body_stream = *result.Value.BodyStream; + + if (scratch.size() < kStreamingBufferBytes) { + scratch.resize(kStreamingBufferBytes); + } + + int64_t remaining = size; + while (remaining > 0) { + size_t to_read = static_cast(std::min(remaining, static_cast(scratch.size()))); + size_t got = body_stream.Read(scratch.data(), to_read, ctx.azure_ctx); + if (got == 0) { + // Zero-byte read before reaching `size` means the server closed early. + // Treat as a hard error rather than silently writing a truncated chunk. + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "short read from blob stream at offset " + std::to_string(offset) + ": got " + + std::to_string(size - remaining) + " of " + std::to_string(size) + " bytes"); + } + sink(scratch.data(), got); + remaining -= static_cast(got); + } +} + +namespace { + +/// Create (truncate to) a zero-byte file at `local_path`, throwing on failure. 
+/// +/// Used only for the empty-blob case below: a 0-length blob has no chunks to +/// stream, so there is nothing for `FileWriter::Open` to pre-allocate — we just +/// materialize the empty file. The chunked path's pre-allocation lives in `Open`. +void EnsureEmptyBlobFile(const std::string& local_path) { + std::ofstream f(local_path, std::ios::binary); + if (!f.is_open()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to create empty blob file: " + local_path); + } +} + +} // namespace + void AzureBlobDownloader::DownloadBlob(const std::string& sas_uri, const std::string& blob_name, const std::string& local_path, @@ -65,155 +148,235 @@ void AzureBlobDownloader::DownloadBlob(const std::string& sas_uri, auto container_client = Azure::Storage::Blobs::BlobContainerClient(sas_uri, client_options); auto blob_client = container_client.GetBlobClient(blob_name); - // Context provides cooperative cancellation across all SDK operations. - Azure::Core::Context ctx; + // Single shared Azure context for the whole blob; calling Cancel() on it + // propagates into every in-flight chunk read. + Azure::Core::Context azure_ctx; + // Internal cancel flag flipped by the orchestrator on first chunk failure + // or by external cancellation; checked by workers between iterations. 
+ std::atomic internal_cancel{false}; - // Get blob size - auto props = blob_client.GetProperties({}, ctx).Value; - int64_t blob_size = props.BlobSize; + ChunkContext chunk_ctx{blob_client, azure_ctx}; - if (blob_size == 0) { - // Empty blob — just create the file - std::ofstream f(local_path, std::ios::binary); - if (!f.is_open()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to create empty blob file: " + local_path); - } + int64_t blob_size = GetBlobSize(chunk_ctx); + if (blob_size == 0) { + EnsureEmptyBlobFile(local_path); + BlobDownloadState::DeleteState(local_path, logger_); return; } // 2MB chunk size matching C# constexpr int64_t kChunkSize = 2 * 1024 * 1024; - int64_t num_chunks = (blob_size + kChunkSize - 1) / kChunkSize; + int32_t num_chunks = static_cast((blob_size + kChunkSize - 1) / kChunkSize); + + // Resume from existing sidecar if it matches the current blob layout. + auto state = BlobDownloadState::LoadState(blob_name, local_path, blob_size, + static_cast(kChunkSize), + num_chunks, logger_); + if (state) { + // Only trust the sidecar if the data file it describes is actually on disk + // at full size. If the data file was truncated or removed (e.g. an external + // cleanup) while the sidecar survived, the chunks it marks complete are gone: + // we would skip re-downloading them, Open() would recreate the file + // zero-filled, and the result would be a silently corrupt file. Discard the + // stale state and start fresh. + std::error_code data_ec; + auto data_size = std::filesystem::file_size(local_path, data_ec); + if (data_ec || data_size != static_cast(blob_size)) { + logger_.Log(LogLevel::Information, + "Resume sidecar for '" + local_path + + "' has no matching full-size data file; starting fresh"); + state.reset(); + } + } - // Pre-allocate the file to the full blob size. - // This lets concurrent chunk writes seek to their offset without a resize race. 
- { - std::ofstream f(local_path, std::ios::binary); - if (!f.is_open()) { + if (!state) { + state = BlobDownloadState::CreateNew(blob_name, local_path, blob_size, + static_cast(kChunkSize), num_chunks); + // Persist the sidecar now, before Open() pre-allocates the data file. + // IsDownloadNeeded treats "data file at full size + no sidecar" as a + // completed download and skips it. The periodic save below does not run + // until save_interval chunks are done (~16 MB), so a crash between + // pre-allocation and that first save would otherwise leave a full-size, + // mostly-empty file with no sidecar that the next run silently accepts as + // complete — serving zeros. Writing the sidecar up front upholds the + // invariant "pre-allocated but unfinished <=> sidecar present" — so if it + // can't be persisted we abort here, before Open() pre-allocates, rather + // than risk a full-size file a later run reads as complete. + if (!state->SaveState(logger_)) { FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to open blob file for pre-allocation: " + local_path); + "failed to persist initial download state for '" + local_path + "'"); } + } - f.seekp(blob_size - 1); - f.put('\0'); - f.close(); - if (f.fail()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to pre-allocate blob file: " + local_path + - " (size=" + std::to_string(blob_size) + ")"); + // Track cumulative bytes for progress reporting; seed with bytes already + // present on disk so percent stays monotonic across resume. + std::atomic bytes_completed{state->CalculateDownloadedSize()}; + if (bytes_written_cb && bytes_completed.load() > 0) { + bytes_written_cb(bytes_completed.load()); + } + + auto pending = state->GetPendingChunks(); + if (pending.empty()) { + // Already complete on disk — drop the sidecar. 
+ BlobDownloadState::DeleteState(local_path, logger_); + if (bytes_written_cb) { + bytes_written_cb(blob_size); } + return; } - // Track cumulative bytes for progress reporting - std::atomic bytes_completed{0}; + // Open the file writer once for the whole download. Open() pre-allocates + // the file to blob_size if needed, preserving any existing bytes from a + // resume. Concurrent WriteAt calls to disjoint ranges are thread-safe — the + // OS arbitrates positional writes to non-overlapping ranges. + FileWriter writer; + writer.Open(local_path, blob_size); + + // Flush the resume sidecar roughly every 16 MB of completed chunks, so a + // hard crash re-downloads at most that much on resume — a fixed bound, + // independent of blob size. Checked only at chunk completion, so it never + // flushes faster than chunks arrive. + constexpr int64_t kBytesPerSidecarSave = 16 * 1024 * 1024; + const int32_t save_interval = + std::max(1, static_cast(kBytesPerSidecarSave / kChunkSize)); + std::atomic chunks_since_save{0}; + + std::mutex error_mutex; + std::exception_ptr first_error; + + // Worker pool: workers race to claim from `pending` via atomic fetch_add. + // On any failure, the first worker to fail records the error, sets + // internal_cancel, and calls azure_ctx.Cancel(); other workers see the + // signal and exit fast. + std::atomic next_pending_idx{0}; + int worker_count = std::min(max_concurrency, static_cast(pending.size())); + if (worker_count < 1) { + worker_count = 1; + } + std::vector> workers; + workers.reserve(static_cast(worker_count)); + + auto worker_body = [&]() { + // Per-worker scratch buffer reused across every chunk this worker + // handles. Streaming downloads fill the scratch in 64 KB pieces and + // forward each piece to the sink, so total transient memory is bounded + // by `worker_count * kStreamingBufferBytes` regardless of chunk size. 
+ std::vector scratch(kStreamingBufferBytes); + + while (true) { + // External cancellation drains the pool as fast as the SDK can unwind. + if (cancelled && cancelled->load(std::memory_order_relaxed)) { + if (!internal_cancel.exchange(true)) { + azure_ctx.Cancel(); + } + return; + } + if (internal_cancel.load(std::memory_order_relaxed)) { + return; + } - // Mutex protects concurrent writes to different offsets in the same file. - // Each chunk opens the file, seeks, and writes — the mutex prevents interleaved I/O. - std::mutex file_mutex; + size_t i = next_pending_idx.fetch_add(1, std::memory_order_relaxed); + if (i >= pending.size()) { + return; + } + int32_t chunk_idx = pending[i]; + int64_t offset = static_cast(chunk_idx) * kChunkSize; + int64_t size = std::min(kChunkSize, blob_size - offset); + + // Sink advances a per-chunk write cursor and forwards each piece to + // the file writer. The writer is responsible for any synchronization + // needed across concurrent workers; we don't take a mutex here. + int64_t written = 0; + auto sink = [&](const uint8_t* data, size_t len) { + writer.WriteAt(offset + written, data, len); + written += static_cast(len); + }; + + try { + DownloadChunkStreaming(chunk_ctx, offset, size, scratch, sink); + + // Account for this chunk and fire the progress callback within the same + // try as the download: on user cancellation bytes_written_cb throws, and + // the catch below runs azure_ctx.Cancel() so peers blocked mid-chunk are + // interrupted immediately rather than only noticing the cancel flag when + // they finish their current chunk. + // Report the global running total so progress stays monotonically + // non-decreasing: concurrent workers complete chunks out of order, and + // the public progress contract must never hand the callback a smaller + // percentage after a larger one. 
+ bytes_completed.fetch_add(size, std::memory_order_relaxed); + if (bytes_written_cb) { + bytes_written_cb(bytes_completed.load(std::memory_order_relaxed)); + } + } catch (...) { + std::lock_guard lock(error_mutex); + if (!first_error) { + first_error = std::current_exception(); + } + if (!internal_cancel.exchange(true)) { + azure_ctx.Cancel(); + } + return; + } - // Download chunks concurrently using a bounded pool of async tasks. - // We launch up to max_concurrency tasks at a time, then wait for the batch to complete. - for (int64_t batch_start = 0; batch_start < num_chunks; batch_start += max_concurrency) { - // Check cancellation between batches - if (cancelled && cancelled->load(std::memory_order_relaxed)) { - ctx.Cancel(); - FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); + bool should_save = false; + { + std::lock_guard lock(state->mutex()); + state->MarkChunkComplete(chunk_idx); + int32_t inc = chunks_since_save.fetch_add(1, std::memory_order_relaxed) + 1; + // Skip the periodic save once every chunk is done: the finalization + // path below deletes the sidecar on success, so writing a fully + // complete sidecar here would just be undone microseconds later. + if (inc >= save_interval && !state->IsComplete()) { + chunks_since_save.store(0, std::memory_order_relaxed); + should_save = true; + } + } + if (should_save) { + std::lock_guard lock(state->mutex()); + state->SaveState(logger_); + } } + }; - int64_t batch_end = std::min(batch_start + max_concurrency, num_chunks); - std::vector> futures; - futures.reserve(static_cast(batch_end - batch_start)); - - for (int64_t chunk_idx = batch_start; chunk_idx < batch_end; ++chunk_idx) { - int64_t offset = chunk_idx * kChunkSize; - int64_t size = std::min(kChunkSize, blob_size - offset); - - futures.push_back(std::async(std::launch::async, - [&blob_client, &local_path, &file_mutex, &bytes_completed, &bytes_written_cb, - &ctx, offset, size]() { - // Download this range from the blob. 
- // Retry and backoff are handled by the SDK's retry policy. - Azure::Storage::Blobs::DownloadBlobOptions range_opts; - range_opts.Range = Azure::Core::Http::HttpRange{offset, size}; - auto result = blob_client.Download(range_opts, ctx); - auto& body_stream = *result.Value.BodyStream; - - // Read the body into a local buffer - std::vector buffer(static_cast(size)); - size_t total_read = 0; - while (total_read < static_cast(size)) { - size_t bytes_read = body_stream.Read( - buffer.data() + total_read, - static_cast(size) - total_read, - ctx); - - if (bytes_read == 0) { - break; - } - - total_read += bytes_read; - } - - // a zero-byte read before reaching `size` indicates the server closed early. - // Treat as a hard error rather than silently writing a truncated chunk. - if (total_read < static_cast(size)) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "short read from blob stream: got " + - std::to_string(total_read) + " of " + - std::to_string(size) + " bytes at offset " + - std::to_string(offset)); - } - - // Write the chunk to the file at the correct offset - { - std::lock_guard lock(file_mutex); - std::ofstream f(local_path, - std::ios::binary | std::ios::in | std::ios::out); - if (!f.is_open()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to open blob file for write: " + local_path); - } - - f.seekp(offset); - f.write(reinterpret_cast(buffer.data()), - static_cast(total_read)); - if (f.fail()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to write blob chunk to " + local_path + - " at offset " + std::to_string(offset) + - " (" + std::to_string(total_read) + " bytes)"); - } - } - - // Report progress - bytes_completed += static_cast(total_read); - if (bytes_written_cb) { - bytes_written_cb(bytes_completed.load()); - } - })); - } + for (int w = 0; w < worker_count; ++w) { + workers.push_back(std::async(std::launch::async, worker_body)); + } - // Wait for all tasks in this batch, cancelling context on failure + for (auto& f : workers) { try { - for 
(auto& f : futures) { - f.get(); - } + f.get(); } catch (...) { - // Cancel remaining in-flight downloads so futures complete quickly - ctx.Cancel(); - for (auto& f : futures) { - try { - if (f.valid()) { - f.get(); - } - } catch (...) { - } + // Worker bodies should already have routed exceptions through + // first_error, but stay defensive in case std::async signals one. + std::lock_guard lock(error_mutex); + if (!first_error) { + first_error = std::current_exception(); } - throw; + internal_cancel.store(true, std::memory_order_relaxed); } } + + // Release the OS handle before persisting / deleting the sidecar so any + // observer that watches the data file sees a fully-closed handle. + writer.Close(); + + const bool was_cancelled = cancelled && cancelled->load(std::memory_order_relaxed); + if (first_error || was_cancelled) { + // Persist what we have so the next attempt resumes from here. + { + std::lock_guard lock(state->mutex()); + state->SaveState(logger_); + } + if (was_cancelled) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); + } + std::rethrow_exception(first_error); + } + + // All chunks done — sidecar is no longer needed. + BlobDownloadState::DeleteState(local_path, logger_); } catch (const Azure::Core::OperationCancelledException&) { FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); } catch (const Azure::Core::RequestFailedException& e) { @@ -247,6 +410,34 @@ std::string ComputeRelativePath(const std::string& prefix, const std::string& bl return blob_name.substr(trim); } +/// Returns false if a file at `local_path` already matches the blob's expected +/// `content_length` exactly AND has no `.dlstate` sidecar — in which case the +/// caller can skip the download. Returns true (download needed) for any of: +/// missing file, size mismatch, sidecar present (file may be pre-allocated +/// with holes), or filesystem-stat errors (treat as "redownload to be safe"). 
+bool IsDownloadNeeded(const BlobItemInfo& blob, const std::string& local_path) { + std::error_code ec; + auto status = std::filesystem::status(local_path, ec); + if (ec || !std::filesystem::exists(status) || !std::filesystem::is_regular_file(status)) { + return true; + } + auto size = std::filesystem::file_size(local_path, ec); + if (ec) { + return true; + } + if (static_cast(size) != blob.content_length) { + return true; + } + // The data file is at the expected size, but a sidecar means a previous run + // pre-allocated then aborted mid-download. The file has holes; let + // AzureBlobDownloader resume from the sidecar. + auto sidecar = BlobDownloadState::GetStateFilePath(local_path); + if (std::filesystem::exists(sidecar, ec)) { + return true; + } + return false; +} + } // anonymous namespace void DownloadBlobsToDirectory(IBlobDownloader& downloader, @@ -304,25 +495,60 @@ void DownloadBlobsToDirectory(IBlobDownloader& downloader, return a.first.content_length < b.first.content_length; }); - // Step 4: Calculate total size for progress + // Step 4: Calculate total size across every in-scope blob, including those + // already present on disk. int64_t total_size = 0; for (const auto& [blob, _] : blobs_to_download) { total_size += blob.content_length; } - // Step 4.5: Emit 0% so callers know the download has started + // Step 5: Skip blobs already present at the expected size. Their bytes + // count toward "downloaded" so the percentage stays accurate when this is a + // resume of a partially-completed download. + int64_t skipped_bytes = 0; + blobs_to_download.erase( + std::remove_if(blobs_to_download.begin(), blobs_to_download.end(), + [&skipped_bytes](const auto& pair) { + if (IsDownloadNeeded(pair.first, pair.second)) { + return false; + } + skipped_bytes += pair.first.content_length; + return true; + }), + blobs_to_download.end()); + + // Step 6: Emit initial progress reflecting any already-on-disk bytes. 
+ // If everything was skipped, emit 100% directly and return. + if (blobs_to_download.empty()) { + if (options.progress) { + options.progress(100.0f); + } + return; + } + if (options.progress) { - int result = options.progress(0.0f); + float initial_percent = + total_size > 0 ? static_cast(skipped_bytes) / static_cast(total_size) * 100.0f : 0.0f; + int result = options.progress(initial_percent); if (result != 0) { FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled by user callback return value"); } } - // Step 5: Download each blob with per-chunk progress. + // Step 7: Download each blob with per-chunk progress. // The cancellation flag is set when the progress callback returns non-zero. // It is shared with chunk download threads so they can exit promptly. std::atomic cancelled{false}; - std::atomic total_downloaded_bytes{0}; + // Seed with skipped bytes so per-chunk progress callbacks compute the right + // overall percentage. + std::atomic total_downloaded_bytes{skipped_bytes}; + + // The user progress callback can be reached from up to max_concurrency chunk + // worker threads at once (per_chunk_progress below). Serialize it so a + // caller's callback (UI handle, counter, logger, IPC) is never entered + // concurrently — the public download progress API does not require callers to + // be thread-safe. 
+ std::mutex progress_mutex; for (const auto& [blob, local_path] : blobs_to_download) { // Check cancellation between blobs @@ -348,7 +574,11 @@ void DownloadBlobsToDirectory(IBlobDownloader& downloader, overall = std::min(overall, total_size); float percent = static_cast(overall) / static_cast(total_size) * 100.0f; - int result = options.progress(percent); + int result; + { + std::lock_guard lock(progress_mutex); + result = options.progress(percent); + } if (result != 0) { cancelled.store(true, std::memory_order_relaxed); FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled by user callback return value"); diff --git a/sdk_v2/cpp/src/download/blob_downloader.h b/sdk_v2/cpp/src/download/blob_downloader.h index f43774a1..5175d0d5 100644 --- a/sdk_v2/cpp/src/download/blob_downloader.h +++ b/sdk_v2/cpp/src/download/blob_downloader.h @@ -11,6 +11,8 @@ namespace fl { +class ILogger; + /// Progress callback: percent is 0.0 to 100.0. Return 0 to continue, non-zero to cancel. using DownloadProgressFn = std::function; @@ -57,8 +59,23 @@ class IBlobDownloader { }; /// Azure Storage Blobs SDK-based implementation of IBlobDownloader. +/// +/// Implements resumable downloads: a `.dlstate` sidecar tracks which 2 MB +/// chunks have completed, and DownloadBlob picks up where a prior aborted run +/// left off. A linked cancellation token cascades the first chunk-level +/// failure to every other in-flight chunk so the worker pool drains quickly. +/// +/// Chunks stream from the blob client into the local file in ~64 KB pieces +/// via a sink callback, so each worker holds a single 64 KB scratch buffer +/// instead of allocating a full chunk's worth of bytes per request. This +/// caps peak memory at roughly `max_concurrency * 64 KB` regardless of how +/// large the blob or the chunk size is. class AzureBlobDownloader : public IBlobDownloader { public: + /// `logger` receives diagnostics only (state-file save/load events). 
It is required: + /// the orchestrator always has a logger, so there is no optional/null case to handle. + explicit AzureBlobDownloader(ILogger& logger); + std::vector ListBlobs(const std::string& sas_uri) override; void DownloadBlob(const std::string& sas_uri, @@ -67,6 +84,38 @@ class AzureBlobDownloader : public IBlobDownloader { int max_concurrency, BlobBytesWrittenFn bytes_written_cb = nullptr, std::atomic* cancelled = nullptr) override; + + protected: + /// Opaque per-blob context. Defined in `blob_downloader.cc`; holds the Azure + /// SDK BlobClient + Context pointers used by the production virtuals. + struct ChunkContext; + + /// Return the blob size in bytes. Production calls `BlobClient::GetProperties`. + virtual int64_t GetBlobSize(ChunkContext& ctx); + + /// Read `size` bytes starting at `offset` from the blob and forward them + /// piecewise to `sink`. Pulls from the blob client referenced by `ctx`. + /// + /// `scratch` is a per-worker reusable buffer (default 64 KB). `sink` must be + /// invoked with strictly contiguous ranges; the cumulative byte count + /// delivered to `sink` must equal `size` on success. + /// + /// Must throw on failure. Implementations should observe the cancellation + /// flag accessible via `ctx` and exit promptly when cancellation is requested. + virtual void DownloadChunkStreaming(ChunkContext& ctx, + int64_t offset, + int64_t size, + std::vector& scratch, + const std::function& sink); + + /// Reports whether cooperative cancellation has been requested for this + /// download. The orchestrator calls `Azure::Core::Context::Cancel()` after a + /// sibling chunk fails or on external cancellation, and the Azure SDK + /// interrupts in-flight transfers as a result. + bool IsCancellationRequested(ChunkContext& ctx); + + private: + ILogger& logger_; }; /// High-level download function: enumerate, filter, and download all blobs from a SAS URI. 
diff --git a/sdk_v2/cpp/src/download/cross_process_file_lock.cc b/sdk_v2/cpp/src/download/cross_process_file_lock.cc new file mode 100644 index 00000000..5c01a334 --- /dev/null +++ b/sdk_v2/cpp/src/download/cross_process_file_lock.cc @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "download/cross_process_file_lock.h" +#include "exception.h" +#include "logger.h" + +#include + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace fl { + +namespace { + +constexpr const char* kLockFileName = ".download.lock"; + +/// `PID:,Time:\n` +std::string FormatProcessInfo() { +#ifdef _WIN32 + auto pid = static_cast(_getpid()); +#else + auto pid = static_cast(getpid()); +#endif + auto t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + std::tm tm{}; +#ifdef _WIN32 + gmtime_s(&tm, &t); +#else + gmtime_r(&t, &tm); +#endif + std::ostringstream oss; + oss << "PID:" << pid << ",Time:" << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ") << '\n'; + return oss.str(); +} + +} // namespace + +// Platform-specific resource handle. The destructor here is the only thing +// that releases the lock; CrossProcessFileLock's destructor is defaulted. +#ifdef _WIN32 +struct CrossProcessFileLock::State { + HANDLE handle; + ~State() { + if (handle != INVALID_HANDLE_VALUE) { + // FILE_FLAG_DELETE_ON_CLOSE removes the file when the last handle closes. + CloseHandle(handle); + } + } +}; +#else +struct CrossProcessFileLock::State { + int fd; + std::filesystem::path path; + ~State() { + if (fd >= 0) { + // Unlink before close so the file disappears the instant the lock + // releases; a concurrent acquirer simply recreates it. 
This is the + // classic flock()+unlink() pattern, and it is safe here because every + // acquirer verifies, while holding the flock, that the inode it locked is + // still the one at `path` (see the fstat/stat check in + // TryAcquireForDirectory). An acquirer that raced in on the old inode + // between our unlink and a third party's recreate will see the inode + // mismatch and retry, so two processes never hold "the lock" at once. + // There is also no protected work between this unlink and close. + ::unlink(path.c_str()); + ::close(fd); + } + } +}; +#endif + +CrossProcessFileLock::CrossProcessFileLock(std::filesystem::path path, + std::unique_ptr state, + ILogger& logger) + : path_(std::move(path)), state_(std::move(state)), logger_(logger) {} + +CrossProcessFileLock::~CrossProcessFileLock() { + // Release the OS handle first so the "released" log message is accurate. + state_.reset(); + logger_.Log(LogLevel::Debug, "CrossProcessFileLock released: " + path_.string()); +} + +std::unique_ptr CrossProcessFileLock::TryAcquireForDirectory( + const std::filesystem::path& directory, ILogger& logger) { + std::error_code ec; + std::filesystem::create_directories(directory, ec); + // Best-effort: if create_directories failed, the platform open below will + // surface a clearer error message. + + auto lock_path = directory / kLockFileName; + std::unique_ptr state; + +#ifdef _WIN32 + // dwShareMode=0 blocks any other open (cross- and in-process) until this + // handle closes. FILE_FLAG_DELETE_ON_CLOSE pairs OPEN_ALWAYS into a + // self-cleaning lock that doesn't require unlink-then-close races. 
+ auto wide = lock_path.wstring(); + HANDLE handle = CreateFileW(wide.c_str(), + GENERIC_READ | GENERIC_WRITE, + 0, + nullptr, + OPEN_ALWAYS, + FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE, + nullptr); + if (handle == INVALID_HANDLE_VALUE) { + DWORD err = GetLastError(); + if (err == ERROR_SHARING_VIOLATION || err == ERROR_LOCK_VIOLATION || err == ERROR_ACCESS_DENIED) { + // SHARING/LOCK_VIOLATION: another handle already holds the share-none + // lock. ACCESS_DENIED: the holder is mid-release — FILE_FLAG_DELETE_ON_CLOSE + // puts the file into STATUS_DELETE_PENDING during the close window, and a + // concurrent open of a delete-pending file is reported as ACCESS_DENIED. + // All three mean "another process has it"; treat as contention so the + // caller retries. (A genuine permission error also lands here and would + // poll until timeout, but the directory was just created successfully so + // that is improbable.) + return nullptr; + } + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "CreateFileW failed for lock '" + lock_path.string() + + "' (GetLastError=" + std::to_string(err) + ")"); + } + + auto info = FormatProcessInfo(); + DWORD written = 0; + WriteFile(handle, info.data(), static_cast(info.size()), &written, nullptr); + FlushFileBuffers(handle); + + state = std::unique_ptr(new State{handle}); +#else + int fd = ::open(lock_path.c_str(), O_CREAT | O_RDWR | O_CLOEXEC, 0644); + if (fd < 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "open failed for lock '" + lock_path.string() + "' (errno=" + std::to_string(errno) + ")"); + } + if (::flock(fd, LOCK_EX | LOCK_NB) != 0) { + int err = errno; + ::close(fd); + if (err == EWOULDBLOCK || err == EAGAIN) { + return nullptr; + } + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "flock failed for '" + lock_path.string() + "' (errno=" + std::to_string(err) + ")"); + } + + // Robust-flock inode check. 
We now hold an exclusive flock on whatever inode + // `fd` refers to, but a releaser unlink()s the lock file in its destructor — + // so between our open() and flock() the path may have been unlinked and a + // third process may have recreated it. If so, we are holding a lock on an + // orphaned inode that guards nothing while the live file at `lock_path` is a + // different inode. Confirm the inode we locked is still the one at the path; + // if not, drop it and report contention so the caller retries against the + // live file. This closes the flock()+unlink() orphan-inode race, which is + // what lets two processes never both believe they hold the lock. + struct stat fd_stat {}; + struct stat path_stat {}; + if (::fstat(fd, &fd_stat) != 0 || ::stat(lock_path.c_str(), &path_stat) != 0 || + fd_stat.st_dev != path_stat.st_dev || fd_stat.st_ino != path_stat.st_ino) { + ::close(fd); // releases the flock on the stale / orphaned inode + return nullptr; + } + + (void)::ftruncate(fd, 0); + auto info = FormatProcessInfo(); + (void)::write(fd, info.data(), info.size()); + + state = std::unique_ptr(new State{fd, lock_path}); +#endif + + logger.Log(LogLevel::Debug, "CrossProcessFileLock acquired: " + lock_path.string()); + return std::unique_ptr( + new CrossProcessFileLock(std::move(lock_path), std::move(state), logger)); +} + +std::unique_ptr CrossProcessFileLock::WaitForDirectoryLock( + const std::filesystem::path& directory, + const CancellationPredicate& is_cancelled, + ILogger& logger, + std::chrono::milliseconds poll_interval, + std::chrono::milliseconds timeout) { + auto deadline = std::chrono::steady_clock::now() + timeout; + // `is_cancelled` is the caller's progress callback, which also serves as the + // liveness heartbeat — it emits 0% on every invocation. 
We therefore poll it + // on a single cadence (once per `poll_interval`) rather than on a separate + // fast cancellation tick: a faster tick would spam the user callback (~10x/s) + // for the entire wait, and cancelling a multi-minute cross-process wait a + // second sooner is imperceptible. There is no separate cancellation channel + // to decouple the heartbeat from. + while (true) { + if (is_cancelled && is_cancelled()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "lock acquisition cancelled"); + } + auto lock = CrossProcessFileLock::TryAcquireForDirectory(directory, logger); + if (lock) { + return lock; + } + if (std::chrono::steady_clock::now() >= deadline) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "timed out waiting for cross-process download lock on '" + directory.string() + "'"); + } + std::this_thread::sleep_for(poll_interval); + } +} + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/cross_process_file_lock.h b/sdk_v2/cpp/src/download/cross_process_file_lock.h new file mode 100644 index 00000000..6c206275 --- /dev/null +++ b/sdk_v2/cpp/src/download/cross_process_file_lock.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include + +namespace fl { + +class ILogger; + +/// RAII exclusive lock backed by an OS-level file lock on +/// `/.download.lock`. Serializes model downloads across processes +/// that share a cache directory. A crash while holding the lock may leave a +/// zero-byte file behind; the next acquirer reopens and re-locks, so the leak +/// is harmless. +class CrossProcessFileLock { + public: + /// Returning true aborts WaitForDirectoryLock with FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED. + using CancellationPredicate = std::function; + + /// Non-blocking acquisition. Returns nullptr if another process currently + /// holds the lock. Creates `directory` if missing. 
Throws fl::Exception on + /// unexpected errors (permission denied, etc.). `logger` receives acquire/ + /// release diagnostics and is required — callers always have one. + static std::unique_ptr TryAcquireForDirectory( + const std::filesystem::path& directory, + ILogger& logger); + + /// Polls TryAcquireForDirectory until the lock is acquired, `is_cancelled()` + /// returns true, or `timeout` elapses. + /// Throws FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED on cancellation, or + /// FOUNDRY_LOCAL_ERROR_INTERNAL on timeout. + static std::unique_ptr WaitForDirectoryLock( + const std::filesystem::path& directory, + const CancellationPredicate& is_cancelled, + ILogger& logger, + std::chrono::milliseconds poll_interval = std::chrono::milliseconds{1250}, + std::chrono::milliseconds timeout = std::chrono::hours{3}); + + ~CrossProcessFileLock(); + + CrossProcessFileLock(const CrossProcessFileLock&) = delete; + CrossProcessFileLock& operator=(const CrossProcessFileLock&) = delete; + CrossProcessFileLock(CrossProcessFileLock&&) = delete; + CrossProcessFileLock& operator=(CrossProcessFileLock&&) = delete; + + /// Path to the lock file (for diagnostics / tests). + const std::filesystem::path& path() const noexcept { return path_; } + + private: + struct State; // Platform-specific; defined in the .cc. + + CrossProcessFileLock(std::filesystem::path path, std::unique_ptr state, ILogger& logger); + + std::filesystem::path path_; + std::unique_ptr state_; + ILogger& logger_; +}; + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/download_manager.cc b/sdk_v2/cpp/src/download/download_manager.cc index 6e3bd64c..b576d3ee 100644 --- a/sdk_v2/cpp/src/download/download_manager.cc +++ b/sdk_v2/cpp/src/download/download_manager.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include "download/download_manager.h" +#include "download/cross_process_file_lock.h" #include "download/inference_model_writer.h" #include "exception.h" +#include "log_level.h" +#include "logger.h" #include "util/path_safety.h" #include "util/region_fallback.h" #include "utils.h" @@ -12,6 +15,8 @@ #include #include #include +#include +#include #include #include @@ -176,9 +181,10 @@ DownloadManager::DownloadManager(std::string cache_directory, std::string_view c : cache_directory_(std::move(cache_directory)), config_region_(NormalizeConfiguredRegion(catalog_region)), max_concurrency_(max_concurrency), + logger_(logger), registry_client_(std::make_unique( kDefaultRegistryRegion, logger, std::make_unique(logger, !disable_region_fallback))), - blob_downloader_(std::make_unique()) {} + blob_downloader_(std::make_unique(logger)) {} DownloadManager::~DownloadManager() = default; @@ -233,15 +239,14 @@ std::string DownloadManager::ComputeModelPath(const ModelInfo& info) const { std::string DownloadManager::DownloadModel(const ModelInfo& info, std::function progress_cb) { - // Serialize all downloads. Concurrent downloads of the same model would race into - // creating the same directory and double-writing inference_model.json; concurrent - // downloads of different models would compete for the same per-blob chunk parallelism. - // A single global lock keeps the model simple and predictable. - std::lock_guard download_guard(download_mutex_); - + // Serialize all model downloads in this process: only one runs at a time, so it + // gets the full network and disk instead of competing with another download. + // The cross-process file lock taken below extends the guarantee across every + // process and app that shares this cache directory. + std::unique_lock download_guard(download_mutex_); auto model_path = ComputeModelPath(info); - // Check if already downloaded (before validating URI — cached models don't need one). 
+ // Fast path: serve the cache without taking the cross-process lock. // A valid cache hit requires: directory exists, no in-progress signal file, and // inference_model.json is present (written by DownloadModel on successful completion). auto signal_path = std::filesystem::path(model_path) / kDownloadSignalFileName; @@ -260,9 +265,45 @@ std::string DownloadManager::DownloadModel(const ModelInfo& info, FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "cannot download model: empty URI (asset_id)"); } - // Create output directory + // Create output directory before taking the cross-process lock, since the lock + // file lives inside it. std::filesystem::create_directories(model_path); + // Serialize across processes that share this cache directory. Inside the + // running process the download mutex already serializes downloads; the file + // lock protects against a second SDK instance (e.g. another service or CLI) + // racing on the same model directory. + auto cancel_pred = [&progress_cb]() -> bool { + // progress_cb returning non-zero is the SDK's cancellation signal. Reusing + // it here also acts as a periodic heartbeat (0%) while we wait for the + // other process to finish. + return progress_cb && progress_cb(0.0f) != 0; + }; + auto lock = CrossProcessFileLock::TryAcquireForDirectory(model_path, logger_); + if (!lock) { + logger_.Log(LogLevel::Information, + "Model download is being performed by another process. Waiting on lock at '" + + model_path + "'..."); + // Don't hold the in-process download mutex while blocking on the cross-process + // lock: that wait can last minutes to hours (another process is downloading), + // and freezing every unrelated in-process model download for that long is far + // worse than the bandwidth contention this mutex exists to prevent. Release it + // for the wait and re-acquire before the cache re-check + download below. 
+ download_guard.unlock(); + lock = CrossProcessFileLock::WaitForDirectoryLock(model_path, cancel_pred, logger_); + download_guard.lock(); + } + + // Another process may have just completed the download we were waiting on. + // Re-check the cache now that we hold the lock. + if (std::filesystem::exists(model_path) && !std::filesystem::exists(signal_path) && + HasInferenceModelJson(model_path)) { + if (progress_cb) { + progress_cb(100.0f); + } + return ResolveEffectiveModelPath(model_path); + } + // Create download signal file { std::ofstream signal(signal_path); diff --git a/sdk_v2/cpp/src/download/download_manager.h b/sdk_v2/cpp/src/download/download_manager.h index c552101b..7099dcb8 100644 --- a/sdk_v2/cpp/src/download/download_manager.h +++ b/sdk_v2/cpp/src/download/download_manager.h @@ -74,13 +74,14 @@ class DownloadManager { // from config. std::string config_region_; int max_concurrency_; + ILogger& logger_; std::unique_ptr registry_client_; std::unique_ptr blob_downloader_; - /// Serializes all DownloadModel calls. Only one model downloads at a time — simpler - /// than per-model locking and avoids contending with the per-blob chunk parallelism - /// (`max_concurrency_`) inside a single download. - mutable std::mutex download_mutex_; + /// Serializes all model downloads in this process: only one runs at a time, so + /// each gets the full network/disk instead of competing with another download. + /// Cross-process serialization is handled separately by CrossProcessFileLock. + std::mutex download_mutex_; }; } // namespace fl diff --git a/sdk_v2/cpp/src/download/file_writer.cc b/sdk_v2/cpp/src/download/file_writer.cc new file mode 100644 index 00000000..0ac02d98 --- /dev/null +++ b/sdk_v2/cpp/src/download/file_writer.cc @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "download/file_writer.h" +#include "exception.h" + +#include + +#include +#include +#include + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#include +#include +#include +#include +#endif + +namespace fl { + +namespace fs = std::filesystem; + +namespace { + +/// Ensure the data file exists at exactly `expected_size`, recreating it at the +/// new size if it currently differs (larger or smaller). An existing file that +/// is already the right size is left intact — the resume path relies on this. +void EnsureFileExistsAtSize(const fs::path& path, int64_t expected_size) { + std::error_code ec; + auto cur_size = fs::file_size(path, ec); + if (!ec) { + if (cur_size == static_cast(expected_size)) { + return; + } + // File exists but is the wrong size — fall through to recreate. + } else if (ec != std::errc::no_such_file_or_directory) { + // Some other stat error (permission, transient NFS hiccup, AV scanner + // holding a handle, etc.). Don't blow away a potentially-intact file just + // because we couldn't read its size; surface the error instead so the + // caller can retry and the existing on-disk progress is preserved. 
+ FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to stat blob file: " + path.string() + " (" + ec.message() + ")"); + } + + std::ofstream f(path, std::ios::binary); + if (!f.is_open()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to open blob file for pre-allocation: " + path.string()); + } + if (expected_size > 0) { + f.seekp(expected_size - 1); + f.put('\0'); + } + f.close(); + if (f.fail()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to pre-allocate blob file: " + path.string() + + " (size=" + std::to_string(expected_size) + ")"); + } +} + +} // namespace + +#ifdef _WIN32 + +FileWriter::~FileWriter() { Close(); } + +void FileWriter::Open(const fs::path& path, int64_t expected_size) { + EnsureFileExistsAtSize(path, expected_size); + // FILE_SHARE_READ | FILE_SHARE_WRITE so the lock file / other tools can peek + // at the partial file without us erroring; positional WriteFile is safe + // regardless of share mode. + HANDLE h = ::CreateFileW(path.wstring().c_str(), GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE, nullptr, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, nullptr); + if (h == INVALID_HANDLE_VALUE) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter open failed for " + path.string() + " (Win32 err " + + std::to_string(::GetLastError()) + ")"); + } + handle_ = h; +} + +void FileWriter::WriteAt(int64_t offset, const uint8_t* data, size_t len) { + // Concurrent WriteFile calls with distinct OVERLAPPED offsets on the same + // handle are safe for non-overlapping ranges; the kernel orders them. + while (len > 0) { + OVERLAPPED ov{}; + // Split the 64-bit file offset across the OVERLAPPED halves: the DWORD casts + // keep the low 32 bits in Offset and the high 32 bits in OffsetHigh. + ov.Offset = static_cast(static_cast(offset)); + ov.OffsetHigh = static_cast(static_cast(offset) >> 32); + DWORD to_write = static_cast(len > 0x7FFFFFFFu ? 
0x7FFFFFFFu : len); + DWORD written = 0; + if (!::WriteFile(handle_, data, to_write, &written, &ov)) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter write failed at offset " + std::to_string(offset) + " (Win32 err " + + std::to_string(::GetLastError()) + ")"); + } + if (written == 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter short write at offset " + std::to_string(offset)); + } + offset += static_cast(written); + data += written; + len -= written; + } +} + +void FileWriter::Close() { + if (handle_ != nullptr) { + ::CloseHandle(handle_); + handle_ = nullptr; + } +} + +#else // POSIX + +FileWriter::~FileWriter() { Close(); } + +void FileWriter::Open(const fs::path& path, int64_t expected_size) { + EnsureFileExistsAtSize(path, expected_size); + fd_ = ::open(path.c_str(), O_RDWR | O_CLOEXEC); + if (fd_ < 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter open failed for " + path.string() + " (errno " + + std::to_string(errno) + ")"); + } +} + +void FileWriter::WriteAt(int64_t offset, const uint8_t* data, size_t len) { + while (len > 0) { + ssize_t n = ::pwrite(fd_, data, len, static_cast(offset)); + if (n < 0) { + if (errno == EINTR) continue; + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter pwrite failed at offset " + std::to_string(offset) + " (errno " + + std::to_string(errno) + ")"); + } + if (n == 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "FileWriter short pwrite at offset " + std::to_string(offset)); + } + offset += n; + data += n; + len -= static_cast(n); + } +} + +void FileWriter::Close() { + if (fd_ >= 0) { + ::close(fd_); + fd_ = -1; + } +} + +#endif + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/file_writer.h b/sdk_v2/cpp/src/download/file_writer.h new file mode 100644 index 00000000..0be20021 --- /dev/null +++ b/sdk_v2/cpp/src/download/file_writer.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include + +namespace fl { + +/// Thread-safe positional writer for blob downloads. +/// +/// Workers in a single download claim disjoint chunks, so concurrent `WriteAt` +/// calls always target non-overlapping byte ranges. Backed by `pwrite` (POSIX) +/// or `WriteFile` + `OVERLAPPED` (Windows): the OS arbitrates concurrent writes +/// to disjoint ranges, so no user-space lock is taken. +class FileWriter { + public: + FileWriter() = default; + ~FileWriter(); + + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + + /// Make `path` exist at exactly `expected_size` bytes. If the file already + /// exists at that size, leave its contents intact so the resume path can pick + /// up where it left off. Called once before the first `WriteAt`. + void Open(const std::filesystem::path& path, int64_t expected_size); + + /// Write `len` bytes from `data` starting at byte offset `offset`. Safe for + /// concurrent calls targeting disjoint ranges. + void WriteAt(int64_t offset, const uint8_t* data, size_t len); + + /// Release the underlying OS handle. Implicitly called by the destructor. + void Close(); + + private: +#ifdef _WIN32 + // Win32 HANDLE. Holds a valid handle while open, nullptr otherwise. 
+ void* handle_ = nullptr; +#else + int fd_ = -1; +#endif +}; + +} // namespace fl diff --git a/sdk_v2/cpp/test/CMakeLists.txt b/sdk_v2/cpp/test/CMakeLists.txt index 08e23caf..fb4aa165 100644 --- a/sdk_v2/cpp/test/CMakeLists.txt +++ b/sdk_v2/cpp/test/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(foundry_local_tests internal_api/audio/audio_transcription_contract_test.cc internal_api/audio/pcm_utils_test.cc internal_api/base_model_catalog_test.cc + internal_api/blob_download_state_test.cc internal_api/c_api_test.cc internal_api/callback_handler_test.cc internal_api/catalog_cache_test.cc @@ -21,6 +22,7 @@ add_executable(foundry_local_tests internal_api/chat_completions_test.cc internal_api/chat_completions_converter_test.cc internal_api/configuration_test.cc + internal_api/cross_process_file_lock_test.cc internal_api/download_test.cc internal_api/embeddings/contracts_embeddings_test.cc internal_api/embeddings/fp16_test.cc @@ -28,6 +30,7 @@ add_executable(foundry_local_tests internal_api/exception_test.cc internal_api/execution_provider_test.cc internal_api/file_uri_test.cc + internal_api/file_writer_test.cc internal_api/genai_config_test.cc internal_api/http_retry_test.cc internal_api/item_test.cc diff --git a/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc b/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc new file mode 100644 index 00000000..cb1fefbc --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "download/blob_download_state.h" +#include "test_helpers.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +using namespace fl; + +namespace { + +class TempDir { + public: + TempDir() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist; + path_ = fs::temp_directory_path() / ("fl_dlstate_test_" + std::to_string(dist(gen))); + fs::create_directories(path_); + } + + ~TempDir() { + std::error_code ec; + fs::remove_all(path_, ec); + } + + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +constexpr int64_t kBlobSize = 20 * 1024 * 1024; // 20 MiB +constexpr int32_t kChunkSize = 2 * 1024 * 1024; // 2 MiB +constexpr int32_t kNumChunks = 10; + +} // namespace + +TEST(BlobDownloadStateTest, GetStateFilePathAppendsDlstate) { + fs::path p = "C:/some/file.bin"; + EXPECT_EQ(BlobDownloadState::GetStateFilePath(p).string(), + (fs::path("C:/some/file.bin.dlstate")).string()); +} + +TEST(BlobDownloadStateTest, CreateNewInitializesEmptyBitmap) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + ASSERT_NE(s, nullptr); + EXPECT_EQ(s->blob_size, kBlobSize); + EXPECT_EQ(s->chunk_size, kChunkSize); + EXPECT_EQ(s->total_chunks, kNumChunks); + EXPECT_EQ(s->completed_count, 0); + EXPECT_EQ(s->highest_completed_chunk, -1); + EXPECT_EQ(s->bitmap_byte_aligned_start, 0); + EXPECT_FALSE(s->IsComplete()); + EXPECT_EQ(s->CalculateDownloadedSize(), 0); + EXPECT_EQ(s->GetPendingChunks().size(), static_cast(kNumChunks)); +} + +TEST(BlobDownloadStateTest, MarkChunkCompleteUpdatesBitmapAndCounter) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(3); + EXPECT_TRUE(s->IsChunkComplete(3)); + EXPECT_FALSE(s->IsChunkComplete(2)); + 
EXPECT_EQ(s->completed_count, 1); + EXPECT_EQ(s->highest_completed_chunk, 3); + EXPECT_EQ(s->CalculateDownloadedSize(), kChunkSize); +} + +TEST(BlobDownloadStateTest, MarkChunkCompleteIsIdempotent) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(5); + s->MarkChunkComplete(5); + s->MarkChunkComplete(5); + EXPECT_EQ(s->completed_count, 1); +} + +TEST(BlobDownloadStateTest, CalculateDownloadedSizeAccountsForPartialFinalChunk) { + TempDir d; + auto local = d.path() / "blob.bin"; + constexpr int64_t kOddBlobSize = 5 * 1024 * 1024 + 17; // last chunk is 17 bytes + constexpr int32_t kOddNumChunks = 3; + auto s = BlobDownloadState::CreateNew("blob", local, kOddBlobSize, kChunkSize, kOddNumChunks); + for (int32_t i = 0; i < kOddNumChunks; ++i) { + s->MarkChunkComplete(i); + } + EXPECT_TRUE(s->IsComplete()); + EXPECT_EQ(s->CalculateDownloadedSize(), kOddBlobSize); +} + +TEST(BlobDownloadStateTest, GetPendingChunksReturnsGaps) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i : {0, 1, 2, 5, 7}) { + s->MarkChunkComplete(i); + } + auto pending = s->GetPendingChunks(); + std::vector expected{3, 4, 6, 8, 9}; + EXPECT_EQ(pending, expected); +} + +TEST(BlobDownloadStateTest, SaveAndLoadRoundTrip) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i : {0, 2, 4, 6, 8}) { + s->MarkChunkComplete(i); + } + s->SaveState(fl::test::NullLog()); + } + auto loaded = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks, + fl::test::NullLog()); + ASSERT_NE(loaded, nullptr); + EXPECT_EQ(loaded->completed_count, 5); + EXPECT_EQ(loaded->highest_completed_chunk, 8); + for (int32_t i : {0, 2, 4, 6, 8}) { + 
EXPECT_TRUE(loaded->IsChunkComplete(i)) << "chunk " << i; + } + for (int32_t i : {1, 3, 5, 7, 9}) { + EXPECT_FALSE(loaded->IsChunkComplete(i)) << "chunk " << i; + } + std::vector expected{1, 3, 5, 7, 9}; + EXPECT_EQ(loaded->GetPendingChunks(), expected); +} + +TEST(BlobDownloadStateTest, SaveStateAdvancesBitmapByteAlignedStart) { + TempDir d; + auto local = d.path() / "blob.bin"; + // Use a large enough total that whole-word advance is meaningful. + constexpr int32_t kBigNumChunks = 200; + constexpr int64_t kBigBlobSize = static_cast(kBigNumChunks) * kChunkSize; + auto s = BlobDownloadState::CreateNew("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks); + // Complete the first 80 chunks (10 full bytes worth). + for (int32_t i = 0; i < 80; ++i) { + s->MarkChunkComplete(i); + } + s->SaveState(fl::test::NullLog()); + // 64 bits = 1 full word; next 16 bits in word 1. Aligned start lands on + // 80 (multiple of 8). + EXPECT_EQ(s->bitmap_byte_aligned_start, 80); + + // Reload and verify the implicit prefix is still considered complete. + auto loaded = BlobDownloadState::LoadState("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks, + fl::test::NullLog()); + ASSERT_NE(loaded, nullptr); + for (int32_t i = 0; i < 80; ++i) { + EXPECT_TRUE(loaded->IsChunkComplete(i)); + } + for (int32_t i = 80; i < kBigNumChunks; ++i) { + EXPECT_FALSE(loaded->IsChunkComplete(i)); + } + EXPECT_EQ(loaded->completed_count, 80); +} + +// Regression: a second SaveState whose contiguous-complete prefix crosses a +// 64-bit word boundary from a non-word-aligned start must not advance +// bitmap_byte_aligned_start past the first still-pending chunk. The advance +// previously accumulated +64 per word onto the unaligned base and overshot by +// (start % 64), silently marking never-downloaded chunks complete on reload. 
+TEST(BlobDownloadStateTest, SaveStateFromUnalignedStartDoesNotMarkPendingComplete) { + TempDir d; + auto local = d.path() / "blob.bin"; + constexpr int32_t kBigNumChunks = 200; + constexpr int64_t kBigBlobSize = static_cast(kBigNumChunks) * kChunkSize; + auto s = BlobDownloadState::CreateNew("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks); + + // First save lands the contiguous prefix on a byte (8) but not a word (64) + // boundary. + for (int32_t i = 0; i < 8; ++i) { + s->MarkChunkComplete(i); + } + s->SaveState(fl::test::NullLog()); + EXPECT_EQ(s->bitmap_byte_aligned_start, 8); + + // Extend the contiguous prefix across the word boundary: chunks 0..64 done, + // chunk 65 is the first still-pending chunk. + for (int32_t i = 8; i <= 64; ++i) { + s->MarkChunkComplete(i); + } + s->SaveState(fl::test::NullLog()); + // Must round down to 64 (the byte boundary at/below the first pending chunk), + // never overshoot to 72. + EXPECT_EQ(s->bitmap_byte_aligned_start, 64); + + // Reload and prove chunks 65..71 (never downloaded) are still pending. 
+ auto loaded = BlobDownloadState::LoadState("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks, + fl::test::NullLog()); + ASSERT_NE(loaded, nullptr); + EXPECT_TRUE(loaded->IsChunkComplete(64)); + for (int32_t i = 65; i < 72; ++i) { + EXPECT_FALSE(loaded->IsChunkComplete(i)) << "chunk " << i << " was never downloaded"; + } + auto pending = loaded->GetPendingChunks(); + ASSERT_FALSE(pending.empty()); + EXPECT_EQ(pending.front(), 65); +} + +TEST(BlobDownloadStateTest, LoadStateReturnsNullWhenFileMissing) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks, fl::test::NullLog()); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsBadMagic) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto sidecar = BlobDownloadState::GetStateFilePath(local); + { + std::ofstream f(sidecar, std::ios::binary); + f << "ZZZZ"; // wrong magic + f.put(static_cast(0)); // version + for (int i = 0; i < 64; ++i) f.put(0); // padding + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks, fl::test::NullLog()); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsBlobSizeMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(fl::test::NullLog()); + } + // Reload with a *different* expected blob_size — should be rejected. 
+ auto s = BlobDownloadState::LoadState("blob", local, kBlobSize + 1, kChunkSize, kNumChunks, + fl::test::NullLog()); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsChunkSizeMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(fl::test::NullLog()); + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize + 1, kNumChunks, + fl::test::NullLog()); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsTotalChunksMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(fl::test::NullLog()); + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks + 1, + fl::test::NullLog()); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, DeleteStateRemovesSidecar) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(fl::test::NullLog()); + } + EXPECT_TRUE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + BlobDownloadState::DeleteState(local, fl::test::NullLog()); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + // Re-deletion when the file is already absent is a no-op (best-effort). 
+ BlobDownloadState::DeleteState(local, fl::test::NullLog()); +} + +TEST(BlobDownloadStateTest, IsCompleteFlipsTrueWhenAllChunksMarked) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i = 0; i < kNumChunks; ++i) { + EXPECT_FALSE(s->IsComplete()); + s->MarkChunkComplete(i); + } + EXPECT_TRUE(s->IsComplete()); + EXPECT_EQ(s->GetPendingChunks().size(), 0u); +} diff --git a/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc b/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc new file mode 100644 index 00000000..34b46496 --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "download/cross_process_file_lock.h" +#include "test_helpers.h" + +#include "exception.h" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#endif + +namespace fs = std::filesystem; + +using namespace fl; + +namespace { + +/// Per-test temp directory. Auto-cleans on destruction so a flaky test never +/// leaks lock files into the system temp dir. 
+class TempDir { + public: + TempDir() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist; + path_ = fs::temp_directory_path() / ("fl_lock_test_" + std::to_string(dist(gen))); + fs::create_directories(path_); + } + + ~TempDir() { + std::error_code ec; + fs::remove_all(path_, ec); + } + + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +} // namespace + +TEST(CrossProcessFileLockTest, TryAcquireSucceedsForFreshDirectory) { + TempDir dir; + + auto lock = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + + ASSERT_NE(lock, nullptr); + EXPECT_TRUE(fs::exists(lock->path())); + EXPECT_EQ(lock->path().parent_path(), dir.path()); + EXPECT_EQ(lock->path().filename(), ".download.lock"); +} + +TEST(CrossProcessFileLockTest, ReleaseOnDestructionRemovesLockFile) { + TempDir dir; + fs::path lock_file; + + { + auto lock = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(lock, nullptr); + lock_file = lock->path(); + EXPECT_TRUE(fs::exists(lock_file)); + } + + // After RAII release the lock file should be gone (Win FILE_FLAG_DELETE_ON_CLOSE, + // POSIX explicit unlink in destructor). 
+ EXPECT_FALSE(fs::exists(lock_file)); +} + +TEST(CrossProcessFileLockTest, SecondAcquireReturnsNullWhileFirstIsHeld) { + TempDir dir; + auto first = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(first, nullptr); + + auto second = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + EXPECT_EQ(second, nullptr); +} + +TEST(CrossProcessFileLockTest, ReacquireSucceedsAfterRelease) { + TempDir dir; + { + auto first = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(first, nullptr); + } + auto reacquired = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + EXPECT_NE(reacquired, nullptr); +} + +TEST(CrossProcessFileLockTest, CreatesDirectoryIfMissing) { + TempDir parent; + auto missing = parent.path() / "nested" / "model"; + + ASSERT_FALSE(fs::exists(missing)); + + auto lock = CrossProcessFileLock::TryAcquireForDirectory(missing, fl::test::NullLog()); + + ASSERT_NE(lock, nullptr); + EXPECT_TRUE(fs::is_directory(missing)); + EXPECT_TRUE(fs::exists(missing / ".download.lock")); +} + +TEST(CrossProcessFileLockTest, WaitForLockReturnsImmediatelyWhenAvailable) { + TempDir dir; + + auto start = std::chrono::steady_clock::now(); + auto lock = CrossProcessFileLock::WaitForDirectoryLock(dir.path(), []() { return false; }, fl::test::NullLog()); + auto elapsed = std::chrono::steady_clock::now() - start; + + ASSERT_NE(lock, nullptr); + // Fast-path acquisition should be well under 100 ms. + EXPECT_LT(elapsed, std::chrono::milliseconds(500)); +} + +TEST(CrossProcessFileLockTest, WaitForLockAcquiresAfterHolderReleases) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(holder, nullptr); + + // Release the holder after a short delay on another thread. 
+ std::thread releaser([&] { + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + holder.reset(); + }); + + auto start = std::chrono::steady_clock::now(); + auto lock = CrossProcessFileLock::WaitForDirectoryLock( + dir.path(), []() { return false; }, /*logger=*/fl::test::NullLog(), + /*poll_interval=*/std::chrono::milliseconds(100), /*timeout=*/std::chrono::seconds(10)); + auto elapsed = std::chrono::steady_clock::now() - start; + + releaser.join(); + ASSERT_NE(lock, nullptr); + EXPECT_GE(elapsed, std::chrono::milliseconds(200)); + EXPECT_LT(elapsed, std::chrono::seconds(5)); +} + +TEST(CrossProcessFileLockTest, WaitForLockThrowsOnCancellation) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(holder, nullptr); + + std::atomic cancel{false}; + std::thread canceller([&] { + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + cancel.store(true); + }); + + try { + (void)CrossProcessFileLock::WaitForDirectoryLock( + dir.path(), [&cancel]() { return cancel.load(); }, /*logger=*/fl::test::NullLog(), + /*poll_interval=*/std::chrono::milliseconds(100), /*timeout=*/std::chrono::seconds(10)); + canceller.join(); + FAIL() << "expected fl::Exception(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED)"; + } catch (const Exception& ex) { + canceller.join(); + EXPECT_EQ(ex.code(), FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED); + } +} + +TEST(CrossProcessFileLockTest, WaitForLockThrowsOnTimeout) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + ASSERT_NE(holder, nullptr); + + try { + (void)CrossProcessFileLock::WaitForDirectoryLock( + dir.path(), []() { return false; }, /*logger=*/fl::test::NullLog(), + /*poll_interval=*/std::chrono::milliseconds(50), /*timeout=*/std::chrono::milliseconds(200)); + FAIL() << "expected fl::Exception(FOUNDRY_LOCAL_ERROR_INTERNAL)"; + } catch (const Exception& ex) { + EXPECT_EQ(ex.code(), 
FOUNDRY_LOCAL_ERROR_INTERNAL); + std::string what = ex.what(); + EXPECT_NE(what.find("timed out"), std::string::npos); + } +} + +#ifndef _WIN32 +// A genuine cross-PROCESS test (POSIX, i.e. macOS/Linux): fork a child that +// holds the lock, then verify (a) this process is locked out while the child +// holds it and (b) the kernel releases the flock when the child *exits* — even +// though the child leaves the lock file on disk, mirroring a downloader that +// crashed mid-download. Windows share-none contention is already covered +// in-process by SecondAcquireReturnsNullWhileFirstIsHeld (dwShareMode=0 is +// enforced identically for same- and cross-process opens). +TEST(CrossProcessFileLockTest, HeldAcrossProcessesAndReleasedWhenHolderExits) { + TempDir dir; + const auto acquired_signal = dir.path() / "child_acquired"; + const auto release_signal = dir.path() / "parent_done"; + + const pid_t pid = ::fork(); + ASSERT_NE(pid, -1) << "fork failed"; + + if (pid == 0) { + // CHILD: acquire, announce, wait (bounded) for the parent, then _exit while + // still holding it. _exit skips C++/gtest teardown — correct for a forked + // child — so the lock's destructor never runs and the file is left behind; + // the kernel still drops the flock on process exit. + auto lock = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + if (lock == nullptr) { + _exit(2); + } + { std::ofstream(acquired_signal).put('x'); } + for (int i = 0; i < 200 && !fs::exists(release_signal); ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(25)); + } + _exit(0); + } + + // PARENT: wait for the child to take the lock (up to ~5 s). 
+ bool child_acquired = false; + for (int i = 0; i < 200 && !child_acquired; ++i) { + if (fs::exists(acquired_signal)) { + child_acquired = true; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(25)); + } + } + ASSERT_TRUE(child_acquired) << "child process never acquired the lock"; + + // A different process holds it — we must be locked out. + EXPECT_EQ(CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()), nullptr); + + // Release the child and reap it. + { std::ofstream(release_signal).put('x'); } + int status = 0; + ASSERT_EQ(::waitpid(pid, &status, 0), pid); + EXPECT_TRUE(WIFEXITED(status)) << "child did not exit normally"; + EXPECT_EQ(WEXITSTATUS(status), 0) << "child failed to acquire the lock"; + + // The holder process is gone: the kernel released its flock even though the + // lock file is still on disk, so the next acquirer simply re-locks it. + auto reacquired = CrossProcessFileLock::TryAcquireForDirectory(dir.path(), fl::test::NullLog()); + EXPECT_NE(reacquired, nullptr) << "lock not released after the holder process exited"; +} +#endif // !_WIN32 diff --git a/sdk_v2/cpp/test/internal_api/download_test.cc b/sdk_v2/cpp/test/internal_api/download_test.cc index 38215c6c..291b63f6 100644 --- a/sdk_v2/cpp/test/internal_api/download_test.cc +++ b/sdk_v2/cpp/test/internal_api/download_test.cc @@ -8,7 +8,9 @@ // - DownloadManager (full flow orchestration) #include "catalog/azure_catalog_client.h" #include "catalog/azure_catalog_models.h" +#include "download/blob_download_state.h" #include "download/blob_downloader.h" +#include "download/cross_process_file_lock.h" #include "download/download_manager.h" #include "download/inference_model_writer.h" #include "download/model_registry_client.h" @@ -23,9 +25,13 @@ #include #include +#include #include +#include #include #include +#include +#include #include #include #include @@ -517,6 +523,99 @@ TEST(BlobDownloadTest, HandlesEmptyBlobList) { 
EXPECT_TRUE(mock.downloaded_blobs.empty()); } +// ======================================================================== +// Skip-existing (Increment 1: resumable downloads) +// ======================================================================== + +TEST(BlobDownloadTest, SkipsExistingFilesWithCorrectSize) { + TempDir tmpdir; + // Pre-create one of the blobs at the expected size on disk. + std::ofstream(tmpdir.path() / "weights.safetensors") << std::string(1000, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"weights.safetensors", 1000}, + {"config.json", 100}, + }; + + BlobDownloadOptions opts; + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + // Only the missing blob should be downloaded. + ASSERT_EQ(mock.downloaded_blobs.size(), 1u); + EXPECT_EQ(mock.downloaded_blobs[0], "config.json"); +} + +TEST(BlobDownloadTest, RedownloadsFilesWithWrongSize) { + TempDir tmpdir; + // Existing file is truncated relative to the expected blob size. + std::ofstream(tmpdir.path() / "weights.safetensors") << std::string(500, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"weights.safetensors", 1000}, + }; + + BlobDownloadOptions opts; + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + // Wrong-size files should be redownloaded (the mock overwrites them). + ASSERT_EQ(mock.downloaded_blobs.size(), 1u); + EXPECT_EQ(mock.downloaded_blobs[0], "weights.safetensors"); +} + +TEST(BlobDownloadTest, ReportsSkippedBytesInInitialProgress) { + TempDir tmpdir; + // 500 of 2000 bytes already on disk → initial progress should be 25%. 
+ std::ofstream(tmpdir.path() / "already.bin") << std::string(500, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"already.bin", 500}, + {"missing.bin", 1500}, + }; + + std::vector progress_values; + BlobDownloadOptions opts; + opts.progress = [&](float pct) { + progress_values.push_back(pct); + return 0; + }; + + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + ASSERT_FALSE(progress_values.empty()); + // First emitted progress reflects the already-on-disk bytes (500/2000 = 25%). + EXPECT_NEAR(progress_values.front(), 100.0f * 500.0f / 2000.0f, 0.5f); + // Final progress must hit 100%. + EXPECT_FLOAT_EQ(progress_values.back(), 100.0f); +} + +TEST(BlobDownloadTest, EmitsHundredPercentWhenEverythingIsCached) { + TempDir tmpdir; + std::ofstream(tmpdir.path() / "a.bin") << std::string(100, 'A'); + std::ofstream(tmpdir.path() / "b.bin") << std::string(200, 'B'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"a.bin", 100}, + {"b.bin", 200}, + }; + + std::vector progress_values; + BlobDownloadOptions opts; + opts.progress = [&](float pct) { + progress_values.push_back(pct); + return 0; + }; + + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + EXPECT_TRUE(mock.downloaded_blobs.empty()); + ASSERT_FALSE(progress_values.empty()); + EXPECT_FLOAT_EQ(progress_values.front(), 100.0f); +} + // ======================================================================== // Path-traversal hardening (security) // ======================================================================== @@ -1072,6 +1171,154 @@ TEST(DownloadManagerTest, ConcurrentDownloadsOfSameModelSerialize) { } } +// All model downloads serialize through the process-wide download_mutex_, even +// for two *different* models. 
A concurrency probe records the peak number of +// downloads running at once; correct serialization keeps that peak at 1 (the +// second download can't enter until the first releases the mutex). +TEST(DownloadManagerTest, ModelDownloadsSerializeUnderGlobalLock) { + TempDir tmpdir; + DownloadManager manager(tmpdir.string(), "eastus", 64, fl::test::NullLog()); + + auto registry = std::make_unique( + "eastus", fl::test::NullLog(), std::make_unique(fl::test::NullLog(), false), + [](const std::string&) { + return MakeRegistryResponse( + R"({"blobSasUri": "https://storage.blob.core.windows.net/c?sig=test"})"); + }); + manager.SetModelRegistryClient(std::move(registry)); + + // Tracks the peak number of downloads running at once. The global download + // mutex must keep this at 1 even for different models. + class ConcurrencyProbe : public IBlobDownloader { + public: + std::atomic active{0}; + std::atomic peak{0}; + + std::vector ListBlobs(const std::string&) override { + return {{"variant-cpu/weights.bin", 16}}; + } + + void DownloadBlob(const std::string&, const std::string& blob_name, + const std::string& local_path, int, + BlobBytesWrittenFn bytes_written_cb, + std::atomic*) override { + int now = ++active; + int prev = peak.load(); + while (now > prev && !peak.compare_exchange_weak(prev, now)) { + } + // Hold long enough that a second concurrent download would overlap here. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(150)); + --active; + + auto parent = fs::path(local_path).parent_path(); + if (!parent.empty()) { + fs::create_directories(parent); + } + std::ofstream f(local_path); + f << "data for " << blob_name; + if (bytes_written_cb) { + bytes_written_cb(16); + } + } + }; + + auto probe = std::make_unique(); + auto* probe_raw = probe.get(); + manager.SetBlobDownloader(std::move(probe)); + + auto make_info = [](const char* id, const char* publisher) { + ModelInfo info; + info.model_id = id; + info.name = id; + info.uri = std::string("azureml://registries/test/models/") + id + "/versions/1"; + info.string_properties[FOUNDRY_LOCAL_MODEL_PROP_PUBLISHER_STR] = publisher; + return info; + }; + auto info_a = make_info("model-a:1", "PubA"); + auto info_b = make_info("model-b:1", "PubB"); + + std::atomic exceptions{0}; + std::thread t1([&] { + try { + manager.DownloadModel(info_a); + } catch (...) { + ++exceptions; + } + }); + std::thread t2([&] { + try { + manager.DownloadModel(info_b); + } catch (...) { + ++exceptions; + } + }); + t1.join(); + t2.join(); + + EXPECT_EQ(exceptions.load(), 0); + EXPECT_EQ(probe_raw->peak.load(), 1) + << "The global download mutex must serialize all model downloads, even for different models."; +} + +// Exercise the cross-process file-lock branch of DownloadModel that +// the in-process-only concurrency tests never reach. A second process (simulated +// here by holding the lock directly) is mid-download on the same model directory. +// DownloadModel must (1) observe the held lock, (2) block in WaitForDirectoryLock +// without holding the in-process download mutex, and (3) once the lock releases +// AND inference_model.json is present, return the cached result via the post-lock +// recheck WITHOUT re-downloading anything. 
+TEST(DownloadManagerTest, WaitsForCrossProcessLockThenServesCachedResult) { + TempDir tmpdir; + DownloadManager manager(tmpdir.string(), "eastus", 64, fl::test::NullLog()); + + // Registry + downloader that must stay untouched if the post-lock recheck works. + auto registry = std::make_unique( + "eastus", fl::test::NullLog(), std::make_unique(fl::test::NullLog(), false), + [](const std::string&) { + return MakeRegistryResponse( + R"({"blobSasUri": "https://storage.blob.core.windows.net/c?sig=test"})"); + }); + manager.SetModelRegistryClient(std::move(registry)); + + auto mock = std::make_unique(); + mock->blobs_to_return = {{"weights.bin", 100}}; // non-empty: a stray download would be visible + auto* mock_raw = mock.get(); + manager.SetBlobDownloader(std::move(mock)); + + ModelInfo info; + info.model_id = "wait-model:1"; + info.name = "wait-model"; + info.uri = "azureml://registries/test/models/wait-model/versions/1"; + info.string_properties[FOUNDRY_LOCAL_MODEL_PROP_PUBLISHER_STR] = "Pub"; + + // Simulate another process holding the model-directory lock mid-download. + auto model_dir = fs::path(tmpdir.string()) / "Pub" / "wait-model-1"; + fs::create_directories(model_dir); + auto held = CrossProcessFileLock::TryAcquireForDirectory(model_dir, fl::test::NullLog()); + ASSERT_NE(held, nullptr); + + std::atomic done{false}; + std::string result; + std::thread worker([&] { result = manager.DownloadModel(info); done.store(true); }); + + // The call must block on the cross-process lock rather than proceed to download. + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_FALSE(done.load()) << "DownloadModel should block while another process holds the lock"; + + // The "other process" finishes: publish inference_model.json, then release the lock. 
+ { + std::ofstream(model_dir / "inference_model.json") << "{}"; + } + held.reset(); + + worker.join(); + + EXPECT_TRUE(done.load()); + EXPECT_EQ(result, model_dir.string()); + EXPECT_TRUE(mock_raw->downloaded_blobs.empty()) + << "Model became available while waiting; the post-lock recheck must skip the download"; +} + // HasInferenceModelJson must return false instead of throwing when the path // it's asked about is not a directory (e.g. a regular file). Previously the // underlying directory_iterator would throw filesystem_error. @@ -1298,3 +1545,379 @@ TEST(DownloadManagerTest, AcceptsNormalModelIdAndPublisher) { EXPECT_NO_THROW(manager.IsModelCached(info)); EXPECT_FALSE(manager.IsModelCached(info)); } + +// ======================================================================== +// AzureBlobDownloader resume + cancel-cascade tests +// Use a subclass that overrides the protected GetBlobSize / DownloadChunkStreaming +// virtuals to bypass the real Azure SDK and simulate per-chunk behavior. +// ======================================================================== + +namespace { + +/// Test double for AzureBlobDownloader. Overrides the protected virtuals so +/// chunked-download orchestration can be exercised without network I/O. +class FakeChunkAzureDownloader : public AzureBlobDownloader { + public: + int64_t blob_size = 0; + + /// Per-call hook. Receives the chunk offset and size plus a `sink` callback + /// that forwards bytes to the file writer. 
Allowed to: + /// - call `sink` zero or more times with strictly contiguous, cumulative + /// `size`-byte ranges to simulate a successful chunk + /// - throw to simulate a transient failure (sink calls so far still hit disk) + /// - sleep / poll cancellation + std::function& sink, + const std::function& is_cancelled)> + chunk_hook; + + std::atomic chunk_call_count{0}; + std::mutex offsets_mutex; + std::vector requested_offsets; + + using AzureBlobDownloader::AzureBlobDownloader; + + // AzureBlobDownloader now requires a logger reference. Tests don't care about + // diagnostics, so default-construct against the shared null logger to keep the + // many `FakeChunkAzureDownloader d;` sites terse. + FakeChunkAzureDownloader() : AzureBlobDownloader(fl::test::NullLog()) {} + + protected: + int64_t GetBlobSize(ChunkContext& /*ctx*/) override { return blob_size; } + + void DownloadChunkStreaming(ChunkContext& ctx, int64_t offset, int64_t size, + std::vector& scratch, + const std::function& sink) override { + chunk_call_count.fetch_add(1); + { + std::lock_guard lock(offsets_mutex); + requested_offsets.push_back(offset); + } + if (chunk_hook) { + chunk_hook(offset, size, sink, [this, &ctx]() { return IsCancellationRequested(ctx); }); + return; + } + // Default: stream the chunk to the sink in scratch-sized pieces, filled + // with the low byte of the offset for verification. 
+ if (scratch.size() < 64 * 1024) { + scratch.resize(64 * 1024); + } + int64_t remaining = size; + while (remaining > 0) { + size_t to_emit = + static_cast(std::min(remaining, static_cast(scratch.size()))); + std::fill_n(scratch.begin(), to_emit, static_cast(offset & 0xFF)); + sink(scratch.data(), to_emit); + remaining -= static_cast(to_emit); + } + } +}; + +} // namespace + +TEST(AzureBlobDownloaderResumeTest, SkipsChunksAlreadyMarkedCompleteInSidecar) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + // Pre-allocate the data file so the downloader takes the resume path. + { + std::ofstream f(local, std::ios::binary); + f.seekp(kBlobSize - 1); + f.put('\0'); + } + // Pre-write a sidecar: chunks 0..4 done, 5..9 pending. + { + auto state = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i = 0; i < 5; ++i) { + state->MarkChunkComplete(i); + } + state->SaveState(fl::test::NullLog()); + } + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/2); + + EXPECT_EQ(d.chunk_call_count.load(), 5); + std::sort(d.requested_offsets.begin(), d.requested_offsets.end()); + std::vector expected{5 * int64_t{kChunkSize}, 6 * int64_t{kChunkSize}, + 7 * int64_t{kChunkSize}, 8 * int64_t{kChunkSize}, + 9 * int64_t{kChunkSize}}; + EXPECT_EQ(d.requested_offsets, expected); + + // Sidecar should be gone on full success. + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); +} + +TEST(AzureBlobDownloaderResumeTest, IgnoresSidecarWhenDataFileTruncated) { + // A valid sidecar marks chunks complete, but the data file was truncated (e.g. + // an external cleanup) while the sidecar survived. 
The downloader must not trust + // the sidecar — those "completed" chunks are no longer on disk — and must + // re-download every chunk rather than leave them as zeros. + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + // Sidecar claims chunks 0..4 are done. + { + auto state = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i = 0; i < 5; ++i) { + state->MarkChunkComplete(i); + } + state->SaveState(fl::test::NullLog()); + } + // ...but the data file is truncated, far smaller than kBlobSize. + { + std::ofstream f(local, std::ios::binary | std::ios::trunc); + f << "truncated"; + } + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/2); + + // The stale sidecar is ignored: every chunk is downloaded, not just 5..9. + EXPECT_EQ(d.chunk_call_count.load(), kNumChunks); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + EXPECT_EQ(fs::file_size(local), static_cast(kBlobSize)); +} + +TEST(AzureBlobDownloaderResumeTest, DownloadsAllChunksWhenSidecarMissing) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 4; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/4); + + EXPECT_EQ(d.chunk_call_count.load(), kNumChunks); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + // Local file is pre-allocated to blob_size during the first pass. 
+ EXPECT_TRUE(fs::exists(local)); + EXPECT_EQ(fs::file_size(local), static_cast(kBlobSize)); +} + +TEST(AzureBlobDownloaderResumeTest, PersistsSidecarOnChunkFailure) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + // Fail when we see the offset of chunk 4 (specifically chosen so several + // chunks land before the failing one across threads). + constexpr int64_t kFailOffset = 4 * int64_t{kChunkSize}; + d.chunk_hook = [&](int64_t offset, int64_t size, + const std::function& sink, + const std::function& /*is_cancelled*/) { + if (offset == kFailOffset) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "simulated chunk failure"); + } + std::vector buf(static_cast(size), static_cast(offset & 0xFF)); + sink(buf.data(), buf.size()); + }; + + EXPECT_THROW( + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/2), + fl::Exception); + + // The sidecar should be persisted so a subsequent call can resume. + EXPECT_TRUE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + + // Verify the persisted sidecar records partial progress — some chunks completed + // before the failure, but not all — so a future resume can skip the ones already + // done and re-fetch only the rest. + auto retry_state = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks, + fl::test::NullLog()); + ASSERT_NE(retry_state, nullptr); + EXPECT_GT(retry_state->completed_count, 0); + EXPECT_LT(retry_state->completed_count, kNumChunks); +} + +// Regression: the sidecar must reach disk before the data file is pre-allocated, +// not only after save_interval chunks. Open() pre-allocates the file to full +// size, and IsDownloadNeeded treats "full-size data file + no sidecar" as a +// completed download. 
So a crash in the window between pre-allocation and the +// first periodic save would otherwise leave a full-size, empty file that the +// next run skips — silently serving zeros. Verify a sidecar is already present +// the moment the first chunk is requested. +TEST(AzureBlobDownloaderResumeTest, SidecarExistsBeforeFirstChunkCompletes) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 100; // far above the per-save chunk interval + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + auto sidecar = BlobDownloadState::GetStateFilePath(local); + std::atomic recorded{false}; + std::atomic sidecar_present_at_first_chunk{false}; + d.chunk_hook = [&](int64_t /*offset*/, int64_t /*size*/, + const std::function& /*sink*/, + const std::function&) { + if (!recorded.exchange(true)) { + // First chunk callback: CreateNew + the initial SaveState + Open() have + // all run, so the sidecar must already exist. Abort before any periodic + // save to mimic an early interruption. + sidecar_present_at_first_chunk.store(fs::exists(sidecar)); + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "stop after first chunk"); + } + }; + + EXPECT_THROW(d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/1), + fl::Exception); + + EXPECT_TRUE(sidecar_present_at_first_chunk.load()) + << "Sidecar must exist before any chunk completes so an early crash stays resumable."; + EXPECT_TRUE(fs::exists(sidecar)); + EXPECT_TRUE(fs::exists(local)); + EXPECT_EQ(fs::file_size(local), static_cast(kBlobSize)); +} + +TEST(AzureBlobDownloaderResumeTest, CleansUpSidecarOnEmptyBlob) { + TempDir tmpdir; + auto local = tmpdir.path() / "empty.bin"; + // Plant a stale sidecar. 
+ { + std::ofstream f(BlobDownloadState::GetStateFilePath(local), std::ios::binary); + f << "stale"; + } + + FakeChunkAzureDownloader d; + d.blob_size = 0; // empty + + d.DownloadBlob(/*sas_uri=*/"", "empty", local.string(), /*max_concurrency=*/4); + + EXPECT_TRUE(fs::exists(local)); + EXPECT_EQ(fs::file_size(local), 0u); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + EXPECT_EQ(d.chunk_call_count.load(), 0); +} + +TEST(AzureBlobDownloaderResumeTest, ChunkFailureCancelsInFlightPeersFast) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + constexpr int64_t kFailOffset = 4 * int64_t{kChunkSize}; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + // The failing chunk throws fast. Every other chunk sleeps for up to 5 s in + // 50-ms slices, polling cancellation. If linked cancellation works, they + // observe it within one slice of the failure and exit promptly. + d.chunk_hook = [](int64_t offset, int64_t size, + const std::function& sink, + const std::function& is_cancelled) { + if (offset == kFailOffset) { + // Give other workers a moment to enter their sleep loop before we throw, + // so we're meaningfully testing the cancel-while-in-flight path. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(75)); + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "simulated chunk failure"); + } + for (int i = 0; i < 100; ++i) { + if (is_cancelled()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "cancelled mid-chunk"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + std::vector buf(static_cast(size), 0); + sink(buf.data(), buf.size()); + }; + + auto start = std::chrono::steady_clock::now(); + EXPECT_THROW( + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/kNumChunks), + fl::Exception); + auto elapsed = std::chrono::steady_clock::now() - start; + auto elapsed_ms = std::chrono::duration_cast(elapsed).count(); + + // Without cancellation, the slow chunks would sleep ~5 s. With it, they + // should all exit within a few hundred ms of the failure (well under 2 s). + EXPECT_LT(elapsed_ms, 2000) + << "Cancel-cascade should drain in-flight peers fast; took " << elapsed_ms << " ms"; +} + +TEST(AzureBlobDownloaderResumeTest, UserCancelDrainsInFlightPeersFast) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + // Chunk 0 is the cancel trigger; chunks 1..9 are the in-flight peers. The peers + // announce themselves and then sleep up to 5 s in 50-ms slices, polling the + // Azure-context cancellation. Chunk 0 waits until every peer is parked in that + // sleep loop before it completes, so no peer is at the worker top-of-loop to + // observe the shared cancel flag directly -- the only way they can exit + // promptly is the azure_ctx.Cancel() driven by the user-cancel throw. 
+ std::atomic peers_parked{0}; + d.chunk_hook = [&peers_parked](int64_t offset, int64_t size, + const std::function& sink, + const std::function& is_cancelled) { + if (offset == 0) { + for (int i = 0; i < 400 && peers_parked.load() < kNumChunks - 1; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + std::vector buf(static_cast(size), 0); + sink(buf.data(), buf.size()); + return; + } + peers_parked.fetch_add(1); + for (int i = 0; i < 100; ++i) { + if (is_cancelled()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "cancelled mid-chunk"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + std::vector buf(static_cast(size), 0); + sink(buf.data(), buf.size()); + }; + + // Mirror per_chunk_progress: the first progress callback cancels by setting the + // shared flag and throwing. + std::atomic cancelled{false}; + BlobBytesWrittenFn cancel_on_first_progress = [&cancelled](int64_t /*bytes*/) { + cancelled.store(true, std::memory_order_relaxed); + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled by user callback return value"); + }; + + auto start = std::chrono::steady_clock::now(); + EXPECT_THROW(d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/kNumChunks, + cancel_on_first_progress, &cancelled), + fl::Exception); + auto elapsed = std::chrono::steady_clock::now() - start; + auto elapsed_ms = std::chrono::duration_cast(elapsed).count(); + + // Without routing the user-cancel throw through azure_ctx.Cancel(), the parked + // peers would each sleep their full ~5 s before noticing. With it, they exit + // within a slice or two (well under 2 s). 
+ EXPECT_LT(elapsed_ms, 2000) + << "User-cancel should drain in-flight peers fast; took " << elapsed_ms << " ms"; +} diff --git a/sdk_v2/cpp/test/internal_api/file_writer_test.cc b/sdk_v2/cpp/test/internal_api/file_writer_test.cc new file mode 100644 index 00000000..c685506e --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/file_writer_test.cc @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Tests for the FileWriter backing AzureBlobDownloader's chunked writes: +// pre-allocation, resume preservation, and single-thread + concurrent +// disjoint-range positional writes. + +#include "download/file_writer.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; +using namespace fl; + +namespace { + +class TempPath { + public: + TempPath() { + auto base = fs::temp_directory_path(); + std::random_device rd; + std::uniform_int_distribution dist; + path_ = base / ("file_writer_test_" + std::to_string(dist(rd)) + ".bin"); + } + ~TempPath() { + std::error_code ec; + fs::remove(path_, ec); + } + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +} // namespace + +TEST(FileWriterTest, OpenCreatesFileAtRequestedSize) { + TempPath p; + FileWriter w; + w.Open(p.path(), 4096); + w.Close(); + EXPECT_TRUE(fs::exists(p.path())); + EXPECT_EQ(fs::file_size(p.path()), 4096u); +} + +TEST(FileWriterTest, OpenPreservesExistingFileAtSameSize) { + TempPath p; + // Pre-write a sentinel byte the writer must NOT overwrite. + { + std::ofstream f(p.path(), std::ios::binary); + f.seekp(1023); + f.put('\0'); + } + // Plant a known byte at offset 100. 
+ { + std::fstream f(p.path(), std::ios::binary | std::ios::in | std::ios::out); + f.seekp(100); + f.put(static_cast(0xAB)); + } + + FileWriter w; + w.Open(p.path(), 1024); // same size -> must not truncate + w.Close(); + + // Sentinel byte should still be there. + std::ifstream f(p.path(), std::ios::binary); + f.seekg(100); + int byte = f.get(); + EXPECT_EQ(byte, 0xAB); +} + +TEST(FileWriterTest, OpenRecreatesFileWhenSizeDiffers) { + TempPath p; + { + std::ofstream f(p.path(), std::ios::binary); + f.seekp(100); + f.put(static_cast(0xCD)); + } + EXPECT_EQ(fs::file_size(p.path()), 101u); + + FileWriter w; + w.Open(p.path(), 4096); + w.Close(); + EXPECT_EQ(fs::file_size(p.path()), 4096u); +} + +TEST(FileWriterTest, SingleThreadWriteAt) { + TempPath p; + FileWriter w; + w.Open(p.path(), 1024); + + std::vector data(256, 0xEF); + w.WriteAt(512, data.data(), data.size()); + w.Close(); + + std::ifstream f(p.path(), std::ios::binary); + std::vector contents((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + ASSERT_EQ(contents.size(), 1024u); + for (size_t i = 512; i < 768; ++i) { + EXPECT_EQ(contents[i], 0xEF) << "byte " << i; + } +} + +TEST(FileWriterTest, ConcurrentDisjointWritesProduceCorrectFile) { + TempPath p; + constexpr int kThreads = 8; + constexpr int kRegionSize = 256 * 1024; // 256 KB per thread + constexpr int kPieceSize = 16 * 1024; // 16 KB per WriteAt + constexpr int64_t kTotalSize = int64_t{kThreads} * kRegionSize; + static_assert(kRegionSize % kPieceSize == 0, ""); + + FileWriter w; + w.Open(p.path(), kTotalSize); + + std::atomic started{0}; + std::vector workers; + workers.reserve(kThreads); + for (int t = 0; t < kThreads; ++t) { + workers.emplace_back([&, t]() { + std::vector piece(kPieceSize, static_cast(t + 1)); + started.fetch_add(1); + while (started.load() < kThreads) { + // tiny spin to encourage concurrent dispatch + } + const int64_t base = int64_t{t} * kRegionSize; + for (int i = 0; i < kRegionSize / kPieceSize; ++i) { + 
w.WriteAt(base + int64_t{i} * kPieceSize, piece.data(), piece.size()); + } + }); + } + for (auto& th : workers) th.join(); + w.Close(); + + std::ifstream f(p.path(), std::ios::binary); + std::vector contents((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + ASSERT_EQ(contents.size(), static_cast(kTotalSize)); + for (int t = 0; t < kThreads; ++t) { + const uint8_t expected = static_cast(t + 1); + for (int64_t i = 0; i < kRegionSize; ++i) { + const auto idx = static_cast(int64_t{t} * kRegionSize + i); + if (contents[idx] != expected) { + FAIL() << "mismatch at offset " << idx << " (thread " << t << ", expected " + << static_cast(expected) << ", got " << static_cast(contents[idx]) << ")"; + } + } + } +} diff --git a/sdk_v2/cpp/test/sdk_api/download_test.cc b/sdk_v2/cpp/test/sdk_api/download_test.cc index 23c6ffcc..af0592d5 100644 --- a/sdk_v2/cpp/test/sdk_api/download_test.cc +++ b/sdk_v2/cpp/test/sdk_api/download_test.cc @@ -77,7 +77,7 @@ TEST_F(DISABLED_DownloadFixture, RemoveAndRedownloadSmallestModel) { std::vector progress_values; target->Download([&progress_values](float pct) { progress_values.push_back(pct); - return true; // Continue downloading. + return 0; // Continue downloading (0 = continue, non-zero = cancel). }); EXPECT_TRUE(target->IsCached()) @@ -122,7 +122,7 @@ TEST_F(DISABLED_DownloadFixture, DownloadAlreadyCachedModelIsNoOp) { std::vector progress_values; model->Download([&progress_values](float pct) { progress_values.push_back(pct); - return true; + return 0; // 0 = continue, non-zero = cancel. 
}); EXPECT_TRUE(model->IsCached()); @@ -131,3 +131,88 @@ TEST_F(DISABLED_DownloadFixture, DownloadAlreadyCachedModelIsNoOp) { ASSERT_FALSE(progress_values.empty()); EXPECT_FLOAT_EQ(progress_values.back(), 100.0f); } + +TEST_F(DISABLED_DownloadFixture, ResumesPartialDownloadAfterCancel) { + // Live resume check: cancel a real download partway, then re-run and confirm the + // .dlstate sidecar drove a *partial* resume rather than a fresh re-download. + // Pick the smallest uncached CPU model (same selection as the redownload test). + foundry_local::IModel* target = nullptr; + int64_t target_size = std::numeric_limits::max(); + for (const auto& m : model_list()) { + if (m->IsLoaded()) { + continue; + } + for (const auto& v : m->GetVariants()) { + auto vi = v->GetInfo(); + if (vi.DeviceType() != FOUNDRY_LOCAL_DEVICE_CPU) { + continue; + } + int64_t size = vi.FilesizeMb().value_or(0); + if (size > 0 && size < target_size) { + target_size = size; + m->SelectVariant(*v); + target = m.get(); + } + } + } + ASSERT_NE(target, nullptr) << "No unloaded CPU model found in catalog"; + + if (target->IsCached()) { + target->RemoveFromCache(); + } + ASSERT_FALSE(target->IsCached()); + + // First attempt: cancel once aggregate progress passes ~30%, leaving partial + // data plus its .dlstate sidecar(s) on disk. The progress callback returns 0 to + // continue and non-zero to cancel. 
+ float cancel_pct = -1.0f; + bool threw = false; + try { + target->Download([&cancel_pct](float pct) -> int { + if (pct >= 30.0f) { + cancel_pct = pct; + return 1; // cancel + } + return 0; // continue + }); + } catch (const std::exception&) { + threw = true; // cancellation surfaces as a thrown error + } + ASSERT_GE(cancel_pct, 30.0f) << "download never reached the cancel threshold"; + EXPECT_TRUE(threw) << "a cancelled download should surface an error"; + EXPECT_FALSE(target->IsCached()) << "a cancelled download must not report cached"; + + // Second attempt: let it finish, capturing every reported percentage. + std::vector resume_progress; + target->Download([&resume_progress](float pct) -> int { + resume_progress.push_back(pct); + return 0; + }); + + ASSERT_FALSE(resume_progress.empty()); + + // DownloadManager::DownloadModel always emits a 0% heartbeat (progress_cb(0.0f)) + // before the transfer starts, which Model::Download forwards unchanged, so + // resume_progress.front() is that heartbeat -- not the resumed percentage. Skip + // the leading zero(s); the first non-zero sample is the initial on-disk + // reflection DownloadBlobsToDirectory emits from the bytes already present. + float first_real = 0.0f; + for (float pct : resume_progress) { + if (pct > 0.0f) { + first_real = pct; + break; + } + } + ASSERT_GT(first_real, 0.0f) << "resume produced no real progress past the 0% heartbeat"; + + // A sidecar-driven partial resume reports its first real progress at roughly the + // bytes already on disk (~the cancel point), so it lands well above the tiny + // first-chunk fraction a fresh re-download would start from. The threshold is + // intentionally lenient (sidecar saves are bounded at ~16 MB granularity); tune + // it if a different model than the smallest CPU one is selected. 
+ constexpr float kMinResumeProgressPct = 10.0f; + EXPECT_GE(first_real, kMinResumeProgressPct) + << "first real progress " << first_real << "% looks like a fresh re-download, not a resume"; + EXPECT_FLOAT_EQ(resume_progress.back(), 100.0f); + EXPECT_TRUE(target->IsCached()); +}