diff --git a/projects/llm_framework/include/fst/accumulator.h b/projects/llm_framework/include/fst/accumulator.h new file mode 100644 index 00000000..5ae19247 --- /dev/null +++ b/projects/llm_framework/include/fst/accumulator.h @@ -0,0 +1,903 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes to accumulate arc weights. Useful for weight lookahead. + +#ifndef FST_ACCUMULATOR_H_ +#define FST_ACCUMULATOR_H_ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace fst { + +// This class accumulates arc weights using the semiring Plus(). +template +class DefaultAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DefaultAccumulator() {} + + DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {} + + void Init(const Fst &fst, bool copy = false) {} + + void SetState(StateId state) {} + + Weight Sum(Weight w, Weight v) { return Plus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + Adder adder(w); // maintains cumulative sum accurately + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) + adder.Add(aiter->Value().weight); + return adder.Sum(); + } + + constexpr bool Error() const { return false; } + + private: + DefaultAccumulator &operator=(const DefaultAccumulator &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() assuming an +// arc weight has a WeightConvert specialization to and from log64 weights. +template +class LogAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + LogAccumulator() {} + + LogAccumulator(const LogAccumulator &acc, bool safe = false) {} + + void Init(const Fst &fst, bool copy = false) {} + + void SetState(StateId s) {} + + Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + auto sum = w; + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + } + + constexpr bool Error() const { return false; } + + private: + Weight LogPlus(Weight w, Weight v) { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - f1))); + } + } + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + + LogAccumulator &operator=(const LogAccumulator &) = delete; +}; + +// Interface for shareable data for fast log accumulator copies. Holds pointers +// to data only, storage is provided by derived classes. +class FastLogAccumulatorData { + public: + FastLogAccumulatorData(int arc_limit, int arc_period) + : arc_limit_(arc_limit), + arc_period_(arc_period), + weights_ptr_(nullptr), + num_weights_(0), + weight_positions_ptr_(nullptr), + num_positions_(0) {} + + virtual ~FastLogAccumulatorData() {} + + // Cummulative weight per state for all states s.t. # of arcs > arc_limit_ + // with arcs in order. The first element per state is Log64Weight::Zero(). + const double *Weights() const { return weights_ptr_; } + + int NumWeights() const { return num_weights_; } + + // Maps from state to corresponding beginning weight position in weights_. + // osition -1 means no pre-computed weights for that state. + const int *WeightPositions() const { return weight_positions_ptr_; } + + int NumPositions() const { return num_positions_; } + + int ArcLimit() const { return arc_limit_; } + + int ArcPeriod() const { return arc_period_; } + + // Returns true if the data object is mutable and supports SetData(). + virtual bool IsMutable() const = 0; + + // Does not take ownership but may invalidate the contents of weights and + // weight_positions. + virtual void SetData(std::vector *weights, + std::vector *weight_positions) = 0; + + protected: + void Init(int num_weights, const double *weights, int num_positions, + const int *weight_positions) { + weights_ptr_ = weights; + num_weights_ = num_weights; + weight_positions_ptr_ = weight_positions; + num_positions_ = num_positions; + } + + private: + const int arc_limit_; + const int arc_period_; + const double *weights_ptr_; + int num_weights_; + const int *weight_positions_ptr_; + int num_positions_; + + FastLogAccumulatorData(const FastLogAccumulatorData &) = delete; + FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete; +}; + +// FastLogAccumulatorData with mutable storage; filled by +// FastLogAccumulator::Init. +class MutableFastLogAccumulatorData : public FastLogAccumulatorData { + public: + MutableFastLogAccumulatorData(int arc_limit, int arc_period) + : FastLogAccumulatorData(arc_limit, arc_period) {} + + bool IsMutable() const override { return true; } + + void SetData(std::vector *weights, + std::vector *weight_positions) override { + weights_.swap(*weights); + weight_positions_.swap(*weight_positions); + Init(weights_.size(), weights_.data(), weight_positions_.size(), + weight_positions_.data()); + } + + private: + std::vector weights_; + std::vector weight_positions_; + + MutableFastLogAccumulatorData(const MutableFastLogAccumulatorData &) = delete; + MutableFastLogAccumulatorData &operator=( + const MutableFastLogAccumulatorData &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() assuming an +// arc weight has a WeightConvert specialization to and from log64 weights. The +// member function Init(fst) has to be called to setup pre-computed weight +// information. +template +class FastLogAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) + : to_log_weight_(), + to_weight_(), + arc_limit_(arc_limit), + arc_period_(arc_period), + data_(std::make_shared(arc_limit, + arc_period)), + state_weights_(nullptr), + error_(false) {} + + explicit FastLogAccumulator(std::shared_ptr data) + : to_log_weight_(), + to_weight_(), + arc_limit_(data->ArcLimit()), + arc_period_(data->ArcPeriod()), + data_(data), + state_weights_(nullptr), + error_(false) {} + + FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false) + : to_log_weight_(), + to_weight_(), + arc_limit_(acc.arc_limit_), + arc_period_(acc.arc_period_), + data_(acc.data_), + state_weights_(nullptr), + error_(acc.error_) {} + + void SetState(StateId s) { + const auto *weights = data_->Weights(); + const auto *weight_positions = data_->WeightPositions(); + state_weights_ = nullptr; + if (s < data_->NumPositions()) { + const auto pos = weight_positions[s]; + if (pos >= 0) state_weights_ = &(weights[pos]); + } + } + + Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const { + if (error_) return Weight::NoWeight(); + auto sum = w; + // Finds begin and end of pre-stored weights. + ssize_t index_begin = -1; + ssize_t index_end = -1; + ssize_t stored_begin = end; + ssize_t stored_end = end; + if (state_weights_) { + index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0; + index_end = end / arc_period_; + stored_begin = index_begin * arc_period_; + stored_end = index_end * arc_period_; + } + // Computes sum before pre-stored weights. + if (begin < stored_begin) { + const auto pos_end = std::min(stored_begin, end); + aiter->Seek(begin); + for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + } + // Computes sum between pre-stored weights. + if (stored_begin < stored_end) { + const auto f1 = state_weights_[index_end]; + const auto f2 = state_weights_[index_begin]; + if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2)); + // Commented out for efficiency; adds Zero(). + /* + else { + // explicitly computes if cumulative sum lacks precision + aiter->Seek(stored_begin); + for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + } + */ + } + // Computes sum after pre-stored weights. + if (stored_end < end) { + const auto pos_start = std::max(stored_begin, stored_end); + aiter->Seek(pos_start); + for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + } + return sum; + } + + template + void Init(const FST &fst, bool copy = false) { + if (copy || !data_->IsMutable()) return; + if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) { + FSTERROR() << "FastLogAccumulator: Initialization error"; + error_ = true; + return; + } + std::vector weights; + std::vector weight_positions; + weight_positions.reserve(CountStates(fst)); + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst.NumArcs(s) >= arc_limit_) { + auto sum = FloatLimits::PosInfinity(); + if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1); + weight_positions[s] = weights.size(); + weights.push_back(sum); + size_t narcs = 0; + ArcIterator aiter(fst, s); + aiter.SetFlags(kArcWeightValue | kArcNoCache, kArcFlags); + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + sum = LogPlus(sum, arc.weight); + // Stores cumulative weight distribution per arc_period_. + if (++narcs % arc_period_ == 0) weights.push_back(sum); + } + } + } + data_->SetData(&weights, &weight_positions); + } + + bool Error() const { return error_; } + + std::shared_ptr GetData() const { return data_; } + + private: + static double LogPosExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F + exp(-x)); + } + + static double LogMinusExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) const { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); + } + } + + double LogPlus(double f1, Weight v) const { + const auto f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits::PosInfinity()) { + return f2; + } else if (f1 > f2) { + return f2 - LogPosExp(f1 - f2); + } else { + return f1 - LogPosExp(f2 - f1); + } + } + + // Assumes f1 < f2. + Weight LogMinus(double f1, double f2) const { + if (f2 == FloatLimits::PosInfinity()) { + return to_weight_(Log64Weight(f1)); + } else { + return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); + } + } + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state. + const ssize_t arc_period_; // Saves cumulative weights per arc_period_. + std::shared_ptr data_; + const double *state_weights_; + bool error_; + + FastLogAccumulator &operator=(const FastLogAccumulator &) = delete; +}; + +// Stores shareable data for cache log accumulator copies. All copies share the +// same cache. +template +class CacheLogAccumulatorData { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + CacheLogAccumulatorData(bool gc, size_t gc_limit) + : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} + + CacheLogAccumulatorData(const CacheLogAccumulatorData &data) + : cache_gc_(data.cache_gc_), + cache_limit_(data.cache_limit_), + cache_size_(0) {} + + bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } + + std::vector *GetWeights(StateId s) { + auto it = cache_.find(s); + if (it != cache_.end()) { + it->second.recent = true; + return it->second.weights.get(); + } else { + return nullptr; + } + } + + void AddWeights(StateId s, std::vector *weights) { + if (cache_gc_ && cache_size_ >= cache_limit_) GC(false); + cache_.insert(std::make_pair(s, CacheState(weights, true))); + if (cache_gc_) cache_size_ += weights->capacity() * sizeof(double); + } + + private: + // Cached information for a given state. + struct CacheState { + std::unique_ptr> weights; // Accumulated weights. + bool recent; // Has this state been accessed since last GC? + + CacheState(std::vector *weights, bool recent) + : weights(weights), recent(recent) {} + }; + + // Garbage collect: Deletes from cache states that have not been accessed + // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of + // 'cache_limit_'. If it does not free enough memory, start deleting + // recently accessed states. + void GC(bool free_recent) { + auto cache_target = (2 * cache_limit_) / 3 + 1; + auto it = cache_.begin(); + while (it != cache_.end() && cache_size_ > cache_target) { + auto &cs = it->second; + if (free_recent || !cs.recent) { + cache_size_ -= cs.weights->capacity() * sizeof(double); + cache_.erase(it++); + } else { + cs.recent = false; + ++it; + } + } + if (!free_recent && cache_size_ > cache_target) GC(true); + } + + std::unordered_map cache_; // Cache. + bool cache_gc_; // Enables garbage collection. + size_t cache_limit_; // # of bytes cached. + size_t cache_size_; // # of bytes allowed before GC. + + CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() has a +// WeightConvert specialization to and from log64 weights. It is similar to the +// FastLogAccumator. However here, the accumulated weights are pre-computed and +// stored only for the states that are visited. The member function Init(fst) +// has to be called to setup this accumulator. +template +class CacheLogAccumulator { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, + size_t gc_limit = 10 * 1024 * 1024) + : arc_limit_(arc_limit), + data_(std::make_shared>(gc, gc_limit)), + s_(kNoStateId), + error_(false) {} + + CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false) + : arc_limit_(acc.arc_limit_), + fst_(acc.fst_ ? acc.fst_->Copy() : nullptr), + data_(safe ? std::make_shared>(*acc.data_) + : acc.data_), + s_(kNoStateId), + error_(acc.error_) {} + + // Argument arc_limit specifies the minimum number of arcs to pre-compute. + void Init(const Fst &fst, bool copy = false) { + if (!copy && fst_) { + FSTERROR() << "CacheLogAccumulator: Initialization error"; + error_ = true; + return; + } + fst_.reset(fst.Copy()); + } + + void SetState(StateId s, int depth = 0) { + if (s == s_) return; + s_ = s; + if (data_->CacheDisabled() || error_) { + weights_ = nullptr; + return; + } + if (!fst_) { + FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized"; + error_ = true; + weights_ = nullptr; + return; + } + weights_ = data_->GetWeights(s); + if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) { + weights_ = new std::vector; + weights_->reserve(fst_->NumArcs(s) + 1); + weights_->push_back(FloatLimits::PosInfinity()); + data_->AddWeights(s, weights_); + } + } + + Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (weights_ == nullptr) { + auto sum = w; + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + } else { + Extend(end, aiter); + const auto &f1 = (*weights_)[end]; + const auto &f2 = (*weights_)[begin]; + if (f1 < f2) { + return LogPlus(w, LogMinus(f1, f2)); + } else { + // Commented out for efficiency; adds Zero(). + /* + auto sum = w; + // Explicitly computes if cumulative sum lacks precision. + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + */ + return w; + } + } + } + + // Returns first position from aiter->Position() whose accumulated + // value is greater or equal to w (w.r.t. Zero() < One()). The + // iterator may be repositioned. + template + size_t LowerBound(Weight w, ArcIter *aiter) { + const auto f = to_log_weight_(w).Value(); + auto pos = aiter->Position(); + if (weights_) { + Extend(fst_->NumArcs(s_), aiter); + return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), + f, std::greater()) - + weights_->begin() - 1; + } else { + size_t n = 0; + auto x = FloatLimits::PosInfinity(); + for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { + x = LogPlus(x, aiter->Value().weight); + if (n >= pos && x <= f) break; + } + return n; + } + } + + bool Error() const { return error_; } + + private: + double LogPosExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F + exp(-x)); + } + + double LogMinusExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); + } + } + + double LogPlus(double f1, Weight v) { + const auto f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits::PosInfinity()) { + return f2; + } else if (f1 > f2) { + return f2 - LogPosExp(f1 - f2); + } else { + return f1 - LogPosExp(f2 - f1); + } + } + + // Assumes f1 < f2. + Weight LogMinus(double f1, double f2) { + if (f2 == FloatLimits::PosInfinity()) { + return to_weight_(Log64Weight(f1)); + } else { + return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); + } + } + + // Extends weights up to index 'end'. + template + void Extend(ssize_t end, ArcIter *aiter) { + if (weights_->size() <= end) { + for (aiter->Seek(weights_->size() - 1); weights_->size() <= end; + aiter->Next()) { + weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight)); + } + } + } + + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + ssize_t arc_limit_; // Minimum # of arcs to cache a state. + std::vector *weights_; // Accumulated weights for cur. state. + std::unique_ptr> fst_; // Input FST. + std::shared_ptr> data_; // Cache data. + StateId s_; // Current state. + bool error_; +}; + +// Stores shareable data for replace accumulator copies. +template +class ReplaceAccumulatorData { + public: + using Arc = typename Accumulator::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + + ReplaceAccumulatorData() : state_table_(nullptr) {} + + explicit ReplaceAccumulatorData( + const std::vector &accumulators) + : state_table_(nullptr) { + accumulators_.reserve(accumulators.size()); + for (const auto accumulator : accumulators) { + accumulators_.emplace_back(accumulator); + } + } + + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + state_table_ = state_table; + accumulators_.resize(fst_tuples.size()); + for (Label i = 0; i < accumulators_.size(); ++i) { + if (!accumulators_[i]) { + accumulators_[i].reset(new Accumulator()); + accumulators_[i]->Init(*(fst_tuples[i].second)); + } + fst_array_.emplace_back(fst_tuples[i].second->Copy()); + } + } + + const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); } + + Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); } + + const Fst *GetFst(size_t i) const { return fst_array_[i].get(); } + + private: + const StateTable *state_table_; + std::vector> accumulators_; + std::vector>> fst_array_; +}; + +// This class accumulates weights in a ReplaceFst. The 'Init' method takes as +// input the argument used to build the ReplaceFst and the ReplaceFst state +// table. It uses accumulators of type 'Accumulator' in the underlying FSTs. +template > +class ReplaceAccumulator { + public: + using Arc = typename Accumulator::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + using Weight = typename Arc::Weight; + + ReplaceAccumulator() + : init_(false), + data_(std::make_shared< + ReplaceAccumulatorData>()), + error_(false) {} + + explicit ReplaceAccumulator(const std::vector &accumulators) + : init_(false), + data_(std::make_shared>( + accumulators)), + error_(false) {} + + ReplaceAccumulator(const ReplaceAccumulator &acc, + bool safe = false) + : init_(acc.init_), data_(acc.data_), error_(acc.error_) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator"; + } + if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported"; + } + + // Does not take ownership of the state table, the state table is owned by + // the ReplaceFst. + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + init_ = true; + data_->Init(fst_tuples, state_table); + } + + // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to + // be initialized by calling the Init method above before being passed to + // LookAheadMatcher. + // + // TODO(allauzen): Revisit this. Consider creating a method + // Init(const ReplaceFst&, bool) and using friendship to get access + // to the innards of ReplaceFst. + void Init(const Fst &fst, bool copy = false) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be" + << " initialized before being passed to LookAheadMatcher"; + error_ = true; + } + } + + void SetState(StateId s) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized"; + error_ = true; + return; + } + auto tuple = data_->GetTuple(s); + fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based. + data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); + if ((tuple.prefix_id != 0) && + (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { + offset_ = 1; + offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); + } else { + offset_ = 0; + offset_weight_ = Weight::Zero(); + } + aiter_.reset( + new ArcIterator>(*data_->GetFst(fst_id_), tuple.fst_state)); + } + + Weight Sum(Weight w, Weight v) { + if (error_) return Weight::NoWeight(); + return data_->GetAccumulator(fst_id_)->Sum(w, v); + } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (error_) return Weight::NoWeight(); + auto sum = begin == end ? Weight::Zero() + : data_->GetAccumulator(fst_id_)->Sum( + w, aiter_.get(), begin ? begin - offset_ : 0, + end - offset_); + if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum); + return sum; + } + + bool Error() const { return error_; } + + private: + bool init_; + std::shared_ptr> data_; + Label fst_id_; + size_t offset_; + Weight offset_weight_; + std::unique_ptr>> aiter_; + bool error_; +}; + +// SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it +// are always thread-safe copies. +template +class SafeReplaceAccumulator { + public: + using Arc = typename Accumulator::Arc; + using StateId = typename Arc::StateId; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + + SafeReplaceAccumulator() {} + + SafeReplaceAccumulator(const SafeReplaceAccumulator ©, bool safe) + : SafeReplaceAccumulator(copy) {} + + explicit SafeReplaceAccumulator( + const std::vector &accumulators) { + for (const auto &accumulator : accumulators) { + accumulators_.emplace_back(accumulator, true); + } + } + + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + state_table_ = state_table; + for (Label i = 0; i < fst_tuples.size(); ++i) { + if (i == accumulators_.size()) { + accumulators_.resize(accumulators_.size() + 1); + accumulators_[i].Init(*(fst_tuples[i].second)); + } + fst_array_.emplace_back(fst_tuples[i].second->Copy(true)); + } + init_ = true; + } + + void Init(const Fst &fst, bool copy = false) { + if (!init_) { + FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be" + << " initialized before being passed to LookAheadMatcher"; + error_ = true; + } + } + + void SetState(StateId s) { + auto tuple = state_table_->Tuple(s); + fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based + GetAccumulator(fst_id_)->SetState(tuple.fst_state); + offset_ = 0; + offset_weight_ = Weight::Zero(); + const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state); + if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) { + offset_ = 1; + offset_weight_ = final_weight; + } + aiter_.Set(*GetFst(fst_id_), tuple.fst_state); + } + + Weight Sum(Weight w, Weight v) { + if (error_) return Weight::NoWeight(); + return GetAccumulator(fst_id_)->Sum(w, v); + } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (error_) return Weight::NoWeight(); + if (begin == end) return Weight::Zero(); + auto sum = GetAccumulator(fst_id_)->Sum( + w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_); + if (begin == 0 && end != 0 && offset_ > 0) { + sum = Sum(offset_weight_, sum); + } + return sum; + } + + bool Error() const { return error_; } + + private: + class ArcIteratorPtr { + public: + ArcIteratorPtr() {} + + ArcIteratorPtr(const ArcIteratorPtr ©) {} + + void Set(const Fst &fst, StateId state_id) { + ptr_.reset(new ArcIterator>(fst, state_id)); + } + + ArcIterator> *get() { return ptr_.get(); } + + private: + std::unique_ptr>> ptr_; + }; + + Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; } + + const Fst *GetFst(size_t i) const { return fst_array_[i].get(); } + + const StateTable *state_table_; + std::vector accumulators_; + std::vector>> fst_array_; + ArcIteratorPtr aiter_; + bool init_ = false; + bool error_ = false; + Label fst_id_; + size_t offset_; + Weight offset_weight_; +}; + +} // namespace fst + +#endif // FST_ACCUMULATOR_H_ diff --git a/projects/llm_framework/include/fst/add-on.h b/projects/llm_framework/include/fst/add-on.h new file mode 100644 index 00000000..4a95111f --- /dev/null +++ b/projects/llm_framework/include/fst/add-on.h @@ -0,0 +1,248 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST implementation class to attach an arbitrary object with a read/write +// method to an FST and its file representation. The FST is given a new type +// name. + +#ifndef FST_ADD_ON_H_ +#define FST_ADD_ON_H_ + +#include +#include +#include +#include + +#include + +#include + + +namespace fst { + +// Identifies stream data as an add-on FST. +static constexpr int32 kAddOnMagicNumber = 446681434; + +// Nothing to save. +class NullAddOn { + public: + NullAddOn() {} + + static NullAddOn *Read(std::istream &strm, const FstReadOptions &opts) { + return new NullAddOn(); + } + + bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const { + return true; + } +}; + +// Create a new add-on from a pair of add-ons. +template +class AddOnPair { + public: + // Argument reference count incremented. + AddOnPair(std::shared_ptr a1, std::shared_ptr a2) + : a1_(std::move(a1)), a2_(std::move(a2)) {} + + const A1 *First() const { return a1_.get(); } + + const A2 *Second() const { return a2_.get(); } + + std::shared_ptr SharedFirst() const { return a1_; } + + std::shared_ptr SharedSecond() const { return a2_; } + + static AddOnPair *Read(std::istream &istrm, + const FstReadOptions &opts) { + A1 *a1 = nullptr; + bool have_addon1 = false; + ReadType(istrm, &have_addon1); + if (have_addon1) a1 = A1::Read(istrm, opts); + + A2 *a2 = nullptr; + bool have_addon2 = false; + ReadType(istrm, &have_addon2); + if (have_addon2) a2 = A2::Read(istrm, opts); + + return new AddOnPair(std::shared_ptr(a1), + std::shared_ptr(a2)); + } + + bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const { + bool have_addon1 = a1_ != nullptr; + WriteType(ostrm, have_addon1); + if (have_addon1) a1_->Write(ostrm, opts); + bool have_addon2 = a2_ != nullptr; + WriteType(ostrm, have_addon2); + if (have_addon2) a2_->Write(ostrm, opts); + return true; + } + + private: + std::shared_ptr a1_; + std::shared_ptr a2_; +}; + +namespace internal { + +// Adds an object of type T to an FST. T must support: +// +// T* Read(std::istream &); +// bool Write(std::ostream &); +// +// The resulting type is a new FST implementation. +template +class AddOnImpl : public FstImpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + + // We make a thread-safe copy of the FST by default since an FST + // implementation is expected to not share mutable data between objects. + AddOnImpl(const FST &fst, const string &type, + std::shared_ptr t = std::shared_ptr()) + : fst_(fst, true), t_(std::move(t)) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + // Conversion from const Fst & to F always copies the underlying + // implementation. + AddOnImpl(const Fst &fst, const string &type, + std::shared_ptr t = std::shared_ptr()) + : fst_(fst), t_(std::move(t)) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + // We make a thread-safe copy of the FST by default since an FST + // implementation is expected to not share mutable data between objects. + AddOnImpl(const AddOnImpl &impl) + : fst_(impl.fst_, true), t_(impl.t_) { + SetType(impl.Type()); + SetProperties(fst_.Properties(kCopyProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId s) const { return fst_.Final(s); } + + size_t NumArcs(StateId s) const { return fst_.NumArcs(s); } + + size_t NumInputEpsilons(StateId s) const { return fst_.NumInputEpsilons(s); } + + size_t NumOutputEpsilons(StateId s) const { + return fst_.NumOutputEpsilons(s); + } + + size_t NumStates() const { return fst_.NumStates(); } + + static AddOnImpl *Read(std::istream &strm, + const FstReadOptions &opts) { + FstReadOptions nopts(opts); + FstHeader hdr; + if (!nopts.header) { + hdr.Read(strm, nopts.source); + nopts.header = &hdr; + } + std::unique_ptr> impl( + new AddOnImpl(nopts.header->FstType())); + if (!impl->ReadHeader(strm, nopts, kMinFileVersion, &hdr)) return nullptr; + impl.reset(); + int32 magic_number = 0; + ReadType(strm, &magic_number); // Ensures this is an add-on FST. + if (magic_number != kAddOnMagicNumber) { + LOG(ERROR) << "AddOnImpl::Read: Bad add-on header: " << nopts.source; + return nullptr; + } + FstReadOptions fopts(opts); + fopts.header = nullptr; // Contained header was written out. + std::unique_ptr fst(FST::Read(strm, fopts)); + if (!fst) return nullptr; + std::shared_ptr t; + bool have_addon = false; + ReadType(strm, &have_addon); + if (have_addon) { // Reads add-on object if present. + t = std::shared_ptr(T::Read(strm, fopts)); + if (!t) return nullptr; + } + return new AddOnImpl(*fst, nopts.header->FstType(), t); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + FstWriteOptions nopts(opts); + nopts.write_isymbols = false; // Allows contained FST to hold any symbols. + nopts.write_osymbols = false; + WriteHeader(strm, nopts, kFileVersion, &hdr); + WriteType(strm, kAddOnMagicNumber); // Ensures this is an add-on FST. + FstWriteOptions fopts(opts); + fopts.write_header = true; // Forces writing contained header. + if (!fst_.Write(strm, fopts)) return false; + bool have_addon = !!t_; + WriteType(strm, have_addon); + // Writes add-on object if present. + if (have_addon) t_->Write(strm, opts); + return true; + } + + void InitStateIterator(StateIteratorData *data) const { + fst_.InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const { + fst_.InitArcIterator(s, data); + } + + FST &GetFst() { return fst_; } + + const FST &GetFst() const { return fst_; } + + const T *GetAddOn() const { return t_.get(); } + + std::shared_ptr GetSharedAddOn() const { return t_; } + + void SetAddOn(std::shared_ptr t) { t_ = t; } + + private: + explicit AddOnImpl(const string &type) : t_() { + SetType(type); + SetProperties(kExpanded); + } + + // Current file format version. + static constexpr int kFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + FST fst_; + std::shared_ptr t_; + + AddOnImpl &operator=(const AddOnImpl &) = delete; +}; + +template +constexpr int AddOnImpl::kFileVersion; + +template +constexpr int AddOnImpl::kMinFileVersion; + +} // namespace internal +} // namespace fst + +#endif // FST_ADD_ON_H_ diff --git a/projects/llm_framework/include/fst/arc-arena.h b/projects/llm_framework/include/fst/arc-arena.h new file mode 100644 index 00000000..13fe918a --- /dev/null +++ b/projects/llm_framework/include/fst/arc-arena.h @@ -0,0 +1,232 @@ +#ifndef FST_ARC_ARENA_H_ +#define FST_ARC_ARENA_H_ + +#include +#include +#include +#include +#include +#include + +namespace fst { + +// ArcArena is used for fast allocation of contiguous arrays of arcs. +// +// To create an arc array: +// for each state: +// for each arc: +// arena.PushArc(); +// // Commits these arcs and returns pointer to them. +// Arc *arcs = arena.GetArcs(); +// +// OR +// +// arena.DropArcs(); // Throws away current arcs, reuse the space. +// +// The arcs returned are guaranteed to be contiguous and the pointer returned +// will never be invalidated until the arena is cleared for reuse. +// +// The contents of the arena can be released with a call to arena.Clear() after +// which the arena will restart with an initial allocation capable of holding at +// least all of the arcs requested in the last usage before Clear() making +// subsequent uses of the Arena more efficient. +// +// The max_retained_size option can limit the amount of arc space requested on +// Clear() to avoid excess growth from intermittent high usage. +template +class ArcArena { + public: + explicit ArcArena(size_t block_size = 256, + size_t max_retained_size = 1e6) + : block_size_(block_size), + max_retained_size_(max_retained_size) { + blocks_.emplace_back(MakeSharedBlock(block_size_)); + first_block_size_ = block_size_; + total_size_ = block_size_; + arcs_ = blocks_.back().get(); + end_ = arcs_ + block_size_; + next_ = arcs_; + } + + ArcArena(const ArcArena& copy) + : arcs_(copy.arcs_), next_(copy.next_), end_(copy.end_), + block_size_(copy.block_size_), + first_block_size_(copy.first_block_size_), + total_size_(copy.total_size_), + max_retained_size_(copy.max_retained_size_), + blocks_(copy.blocks_) { + NewBlock(block_size_); + } + + void ReserveArcs(size_t n) { + if (next_ + n < end_) return; + NewBlock(n); + } + + void PushArc(const Arc& arc) { + if (next_ == end_) { + size_t length = next_ - arcs_; + NewBlock(length * 2); + } + *next_ = arc; + ++next_; + } + + const Arc* GetArcs() { + const auto *arcs = arcs_; + arcs_ = next_; + return arcs; + } + + void DropArcs() { next_ = arcs_; } + + size_t Size() { return total_size_; } + + void Clear() { + blocks_.resize(1); + if (total_size_ > first_block_size_) { + first_block_size_ = std::min(max_retained_size_, total_size_); + blocks_.back() = MakeSharedBlock(first_block_size_); + } + total_size_ = first_block_size_; + arcs_ = blocks_.back().get(); + end_ = arcs_ + first_block_size_; + next_ = arcs_; + } + + private: + // Allocates a new block with capacity of at least n or block_size, + // copying incomplete arc sequence from old block to new block. + void NewBlock(size_t n) { + const auto length = next_ - arcs_; + const auto new_block_size = std::max(n, block_size_); + total_size_ += new_block_size; + blocks_.emplace_back(MakeSharedBlock(new_block_size)); + std::copy(arcs_, next_, blocks_.back().get()); + arcs_ = blocks_.back().get(); + next_ = arcs_ + length; + end_ = arcs_ + new_block_size; + } + + std::shared_ptr MakeSharedBlock(size_t size) { + return std::shared_ptr(new Arc[size], std::default_delete()); + } + + Arc *arcs_; + Arc *next_; + const Arc *end_; + size_t block_size_; + size_t first_block_size_; + size_t total_size_; + size_t max_retained_size_; + std::list> blocks_; +}; + +// ArcArenaStateStore uses a resusable ArcArena to store arc arrays and does not +// require that the Expander call ReserveArcs first. +// +// TODO(tombagby): Make cache type configurable. +// TODO(tombagby): Provide ThreadLocal/Concurrent configuration. +template +class ArcArenaStateStore { + public: + using Arc = A; + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + + ArcArenaStateStore() : arena_(64 * 1024) { + } + + class State { + public: + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return narcs_; } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + const Arc *Arcs() const { return arcs_; } + + int* MutableRefCount() const { return nullptr; } + + private: + State(Weight weight, int32 niepsilons, int32 noepsilons, int32 narcs, + const Arc *arcs) + : final_(std::move(weight)), + niepsilons_(niepsilons), + noepsilons_(noepsilons), + narcs_(narcs), + arcs_(arcs) {} + + Weight final_; + size_t niepsilons_; + size_t noepsilons_; + size_t narcs_; + const Arc *arcs_; + + friend class ArcArenaStateStore; + }; + + template + State *FindOrExpand(Expander &expander, StateId state_id) { // NOLINT + auto it = cache_.insert(std::pair(state_id, nullptr)); + if (!it.second) return it.first->second; + // Needs a new state. + StateBuilder builder(&arena_); + expander.Expand(state_id, &builder); + const auto arcs = arena_.GetArcs(); + size_t narcs = builder.narcs_; + size_t niepsilons = 0; + size_t noepsilons = 0; + for (size_t i = 0; i < narcs; ++i) { + if (arcs[i].ilabel == 0) ++niepsilons; + if (arcs[i].olabel == 0) ++noepsilons; + } + states_.emplace_back( + State(builder.final_, niepsilons, noepsilons, narcs, arcs)); + // Places it in the cache. + auto state = &states_.back(); + it.first->second = state; + return state; + } + + State *Find(StateId state_id) const { + auto it = cache_.find(state_id); + return (it == cache_.end()) ? nullptr : it->second; + } + + private: + class StateBuilder { + public: + explicit StateBuilder(ArcArena* arena) + : arena_(arena), final_(Weight::Zero()), narcs_(0) {} + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void ReserveArcs(size_t n) { arena_->ReserveArcs(n); } + + void AddArc(const Arc &arc) { + ++narcs_; + arena_->PushArc(arc); + } + + private: + friend class ArcArenaStateStore; + + ArcArena *arena_; + Weight final_; + size_t narcs_; + }; + + std::unordered_map cache_; + std::deque states_; + ArcArena arena_; +}; + +} // namespace fst + +#endif // FST_ARC_ARENA_H_ diff --git a/projects/llm_framework/include/fst/arc-map.h b/projects/llm_framework/include/fst/arc-map.h new file mode 100644 index 00000000..24db4911 --- /dev/null +++ b/projects/llm_framework/include/fst/arc-map.h @@ -0,0 +1,1285 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to map over/transform arcs e.g., change semirings or +// implement project/invert. Consider using when operation does +// not change the number of arcs (except possibly superfinal arcs). + +#ifndef FST_ARC_MAP_H_ +#define FST_ARC_MAP_H_ + +#include +#include +#include + +#include + +#include +#include + + +namespace fst { + +// Determines how final weights are mapped. +enum MapFinalAction { + // A final weight is mapped into a final weight. An error is raised if this + // is not possible. + MAP_NO_SUPERFINAL, + // A final weight is mapped to an arc to the superfinal state when the result + // cannot be represented as a final weight. The superfinal state will be + // added only if it is needed. + MAP_ALLOW_SUPERFINAL, + // A final weight is mapped to an arc to the superfinal state unless the + // result can be represented as a final weight of weight Zero(). The + // superfinal state is always added (if the input is not the empty FST). + MAP_REQUIRE_SUPERFINAL +}; + +// Determines how symbol tables are mapped. +enum MapSymbolsAction { + // Symbols should be cleared in the result by the map. + MAP_CLEAR_SYMBOLS, + // Symbols should be copied from the input FST by the map. + MAP_COPY_SYMBOLS, + // Symbols should not be modified in the result by the map itself. + // (They may set by the mapper). + MAP_NOOP_SYMBOLS +}; + +// The ArcMapper interfaces defines how arcs and final weights are mapped. +// This is useful for implementing operations that do not change the number of +// arcs (except possibly superfinal arcs). +// +// template +// class ArcMapper { +// public: +// using FromArc = A; +// using ToArc = B; +// +// // Maps an arc type FromArc to arc type ToArc. +// ToArc operator()(const FromArc &arc); +// +// // Specifies final action the mapper requires (see above). +// // The mapper will be passed final weights as arcs of the form +// // Arc(0, 0, weight, kNoStateId). +// MapFinalAction FinalAction() const; +// +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// +// // This specifies the known properties of an FST mapped by this mapper. It +// takes as argument the input FSTs's known properties. +// uint64 Properties(uint64 props) const; +// }; +// +// The ArcMap functions and classes below will use the FinalAction() +// method of the mapper to determine how to treat final weights, e.g., whether +// to add a superfinal state. They will use the Properties() method to set the +// result FST properties. +// +// We include a various map versions below. One dimension of variation is +// whether the mapping mutates its input, writes to a new result FST, or is an +// on-the-fly FST. Another dimension is how we pass the mapper. We allow passing +// the mapper by pointer for cases that we need to change the state of the +// user's mapper. This is the case with the EncodeMapper, which is reused +// during decoding. We also include map versions that pass the mapper by value +// or const reference when this suffices. + +// Maps an arc type A using a mapper function object C, passed +// by pointer. This version modifies its Fst input. +template +void ArcMap(MutableFst *fst, C *mapper) { + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetOutputSymbols(nullptr); + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kFstProperties, false); + const auto final_action = mapper->FinalAction(); + auto superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + const auto state = siter.Value(); + for (MutableArcIterator> aiter(fst, state); + !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + aiter.SetValue((*mapper)(arc)); + } + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + const auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc"; + fst->SetProperties(kError, kError); + } + fst->SetFinal(state, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (state != superfinal) { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. + if (superfinal == kNoStateId) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + final_arc.nextstate = superfinal; + fst->AddArc(state, std::move(final_arc)); + fst->SetFinal(state, Weight::Zero()); + } else { + fst->SetFinal(state, final_arc.weight); + } + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + if (state != superfinal) { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + const auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != Weight::Zero()) { + fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + } + fst->SetFinal(state, Weight::Zero()); + } + break; + } + } + } + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + +// Maps an arc type A using a mapper function object C, passed by value. This +// version modifies its FST input. +template +void ArcMap(MutableFst *fst, C mapper) { + ArcMap(fst, &mapper); +} + +// Maps an arc type A to an arc type B using mapper function object C, +// passed by pointer. This version writes the mapped input FST to an +// output MutableFst. +template +void ArcMap(const Fst &ifst, MutableFst *ofst, C *mapper) { + using FromArc = A; + using StateId = typename FromArc::StateId; + ofst->DeleteStates(); + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetInputSymbols(ifst.InputSymbols()); + } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetOutputSymbols(nullptr); + } + const auto iprops = ifst.Properties(kCopyProperties, false); + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + const auto final_action = mapper->FinalAction(); + if (ifst.Properties(kExpanded, false)) { + ofst->ReserveStates( + CountStates(ifst) + (final_action == MAP_NO_SUPERFINAL ? 0 : 1)); + } + // Adds all states. + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + ofst->AddState(); + } + StateId superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (s == ifst.Start()) ofst->SetStart(s); + ofst->ReserveArcs( + s, ifst.NumArcs(s) + (final_action != MAP_NO_SUPERFINAL ? 1 : 0)); + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + ofst->AddArc(s, (*mapper)(aiter.Value())); + } + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc"; + ofst->SetProperties(kError, kError); + } + ofst->SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. + if (superfinal == kNoStateId) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + final_arc.nextstate = superfinal; + ofst->AddArc(s, std::move(final_arc)); + ofst->SetFinal(s, B::Weight::Zero()); + } else { + ofst->SetFinal(s, final_arc.weight); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) { + ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + } + ofst->SetFinal(s, B::Weight::Zero()); + break; + } + } + } + const auto oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by value. This version writes the mapped input +// Fst to an output MutableFst. +template +void ArcMap(const Fst &ifst, MutableFst *ofst, C mapper) { + ArcMap(ifst, ofst, &mapper); +} + +struct ArcMapFstOptions : public CacheOptions { + // ArcMapFst default caching behaviour is to do no caching. Most mappers are + // cheap and therefore we save memory by not doing caching. + ArcMapFstOptions() : CacheOptions(true, 0) {} + + explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} +}; + +template +class ArcMapFst; + +namespace internal { + +// Implementation of delayed ArcMapFst. +template +class ArcMapFstImpl : public CacheImpl { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::EmplaceArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::PushArc; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + ArcMapFstImpl(const Fst &fst, const C &mapper, + const ArcMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(new C(mapper)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const Fst &fst, C *mapper, const ArcMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const ArcMapFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ~ArcMapFstImpl() override { + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) SetStart(FindOState(fst_->Start())); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc"; + SetProperties(kError, kError); + } + SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (s == superfinal_) { + SetFinal(s, Weight::One()); + } else { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel == 0 && final_arc.olabel == 0) { + SetFinal(s, final_arc.weight); + } else { + SetFinal(s, Weight::Zero()); + } + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero()); + break; + } + } + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + // Add exiting arcs. + if (s == superfinal_) { + SetArcs(s); + return; + } + for (ArcIterator> aiter(*fst_, FindIState(s)); !aiter.Done(); + aiter.Next()) { + auto aarc = aiter.Value(); + aarc.nextstate = FindOState(aarc.nextstate); + PushArc(s, (*mapper_)(aarc)); + } + + // Check for superfinal arcs. + if (!HasFinal(s) || Final(s) == Weight::Zero()) { + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: + break; + case MAP_ALLOW_SUPERFINAL: { + auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + if (superfinal_ == kNoStateId) superfinal_ = nstates_++; + final_arc.nextstate = superfinal_; + PushArc(s, std::move(final_arc)); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) { + EmplaceArc(s, final_arc.ilabel, final_arc.olabel, final_arc.weight, + superfinal_); + } + break; + } + } + } + SetArcs(s); + } + + private: + void Init() { + SetType("map"); + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetInputSymbols(fst_->InputSymbols()); + } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetInputSymbols(nullptr); + } + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetOutputSymbols(fst_->OutputSymbols()); + } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetOutputSymbols(nullptr); + } + if (fst_->Start() == kNoStateId) { + final_action_ = MAP_NO_SUPERFINAL; + SetProperties(kNullProperties); + } else { + final_action_ = mapper_->FinalAction(); + uint64 props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0; + } + } + + // Maps from output state to input state. + StateId FindIState(StateId s) { + if (superfinal_ == kNoStateId || s < superfinal_) { + return s; + } else { + return s - 1; + } + } + + // Maps from input state to output state. + StateId FindOState(StateId is) { + auto os = is; + if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os; + if (os >= nstates_) nstates_ = os + 1; + return os; + } + + std::unique_ptr> fst_; + C *mapper_; + const bool own_mapper_; + MapFinalAction final_action_; + StateId superfinal_; + StateId nstates_; +}; + +} // namespace internal + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed FST. +template +class ArcMapFst : public ImplToFst> { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::ArcMapFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + ArcMapFst(const Fst &fst, const C &mapper, const ArcMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + ArcMapFst(const Fst &fst, C *mapper, const ArcMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + ArcMapFst(const Fst &fst, const C &mapper) + : ImplToFst( + std::make_shared(fst, mapper, ArcMapFstOptions())) {} + + ArcMapFst(const Fst &fst, C *mapper) + : ImplToFst( + std::make_shared(fst, mapper, ArcMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + ArcMapFst(const ArcMapFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this ArcMapFst. See Fst<>::Copy() for further doc. + ArcMapFst *Copy(bool safe = false) const override { + return new ArcMapFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + protected: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + private: + ArcMapFst &operator=(const ArcMapFst &) = delete; +}; + +// Specialization for ArcMapFst. +// +// This may be derived from. +template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename B::StateId; + + explicit StateIterator(const ArcMapFst &fst) + : impl_(fst.GetImpl()), + siter_(*impl_->fst_), + s_(0), + superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) { + CheckSuperfinal(); + } + + bool Done() const final { return siter_.Done() && !superfinal_; } + + StateId Value() const final { return s_; } + + void Next() final { + ++s_; + if (!siter_.Done()) { + siter_.Next(); + CheckSuperfinal(); + } else if (superfinal_) { + superfinal_ = false; + } + } + + void Reset() final { + s_ = 0; + siter_.Reset(); + superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL; + CheckSuperfinal(); + } + + private: + void CheckSuperfinal() { + if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return; + if (!siter_.Done()) { + const auto final_arc = + (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true; + } + } + + const internal::ArcMapFstImpl *impl_; + StateIterator> siter_; + StateId s_; + bool superfinal_; // True if there is a superfinal state and not done. +}; + +// Specialization for ArcMapFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename A::StateId; + + ArcIterator(const ArcMapFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void ArcMapFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Utility Mappers. + +// Mapper that returns its input. +template +class IdentityArcMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { return arc; } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } +}; + +// Mapper that converts all input symbols to epsilon. +template +class InputEpsilonMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(0, arc.olabel, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kSetArcProperties) | kIEpsilons | kILabelSorted; + } +}; + +// Mapper that converts all output symbols to epsilon. +template +class OutputEpsilonMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, 0, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kSetArcProperties) | kOEpsilons | kOLabelSorted; + } +}; + +// Mapper that returns its input with final states redirected to a single +// super-final state. +template +class SuperFinalMapper { + public: + using FromArc = A; + using ToArc = A; + using Label = typename FromArc::Label; + using Weight = typename FromArc::Weight;; + + // Arg allows setting super-final label. + explicit SuperFinalMapper(Label final_label = 0) + : final_label_(final_label) {} + + ToArc operator()(const FromArc &arc) const { + // Super-final arc. + if (arc.nextstate == kNoStateId && arc.weight != Weight::Zero()) { + return ToArc(final_label_, final_label_, arc.weight, kNoStateId); + } else { + return arc; + } + } + + constexpr MapFinalAction FinalAction() const { + return MAP_REQUIRE_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + if (final_label_ == 0) { + return props & kAddSuperFinalProperties; + } else { + return props & kAddSuperFinalProperties & + kILabelInvariantProperties & kOLabelInvariantProperties; + } + } + + private: + Label final_label_; +}; + +// Mapper that leaves labels and nextstate unchanged and constructs a new weight +// from the underlying value of the arc weight. If no weight converter is +// explictly specified, requires that there is a WeightConvert class +// specialization that converts the weights. +template > +class WeightConvertMapper { + public: + using FromArc = A; + using ToArc = B; + using Converter = C; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + constexpr explicit WeightConvertMapper(const Converter &c = Converter()) + : convert_weight_(c) {} + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, convert_weight_(arc.weight), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } + + private: + const Converter convert_weight_; +}; + +// Non-precision-changing weight conversions; consider using more efficient +// Cast method instead. + +using StdToLogMapper = WeightConvertMapper; + +using LogToStdMapper = WeightConvertMapper; + +// Precision-changing weight conversions. + +using StdToLog64Mapper = WeightConvertMapper; + +using LogToLog64Mapper = WeightConvertMapper; + +using Log64ToStdMapper = WeightConvertMapper; + +using Log64ToLogMapper = WeightConvertMapper; + +// Mapper from A to GallicArc. +template +class ToGallicMapper { + public: + using FromArc = A; + using ToArc = GallicArc; + + using SW = StringWeight; + using AW = typename FromArc::Weight; + using GW = typename ToArc::Weight; + + ToArc operator()(const FromArc &arc) const { + // Super-final arc. + if (arc.nextstate == kNoStateId && arc.weight != AW::Zero()) { + return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId); + // Super-non-final arc. + } else if (arc.nextstate == kNoStateId) { + return ToArc(0, 0, GW::Zero(), kNoStateId); + // Epsilon label. + } else if (arc.olabel == 0) { + return ToArc(arc.ilabel, arc.ilabel, GW(SW::One(), arc.weight), + arc.nextstate); + // Regular label. + } else { + return ToArc(arc.ilabel, arc.ilabel, GW(SW(arc.olabel), arc.weight), + arc.nextstate); + } + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return ProjectProperties(props, true) & kWeightInvariantProperties; + } +}; + +// Mapper from GallicArc to A. +template +class FromGallicMapper { + public: + using FromArc = GallicArc; + using ToArc = A; + + using Label = typename ToArc::Label; + using AW = typename ToArc::Weight; + using GW = typename FromArc::Weight; + + explicit FromGallicMapper(Label superfinal_label = 0) + : superfinal_label_(superfinal_label), error_(false) {} + + ToArc operator()(const FromArc &arc) const { + // 'Super-non-final' arc. + if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) { + return A(arc.ilabel, 0, AW::Zero(), kNoStateId); + } + Label l = kNoLabel; + AW weight; + if (!Extract(arc.weight, &weight, &l) || arc.ilabel != arc.olabel) { + FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight + << " for arc with ilabel = " << arc.ilabel + << ", olabel = " << arc.olabel + << ", nextstate = " << arc.nextstate; + error_ = true; + } + if (arc.ilabel == 0 && l != 0 && arc.nextstate == kNoStateId) { + return ToArc(superfinal_label_, l, weight, arc.nextstate); + } else { + return ToArc(arc.ilabel, l, weight, arc.nextstate); + } + } + + constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; + if (error_) outprops |= kError; + return outprops; + } + + private: + template + static bool Extract(const GallicWeight &gallic_weight, + typename A::Weight *weight, typename A::Label *label) { + using GWT = StringWeight; + const GWT &w1 = gallic_weight.Value1(); + const AW &w2 = gallic_weight.Value2(); + typename GWT::Iterator iter1(w1); + const Label l = w1.Size() == 1 ? iter1.Value() : 0; + if (l == kStringInfinity || l == kStringBad || w1.Size() > 1) return false; + *label = l; + *weight = w2; + return true; + } + + static bool Extract(const GallicWeight &gallic_weight, + typename A::Weight *weight, typename A::Label *label) { + if (gallic_weight.Size() > 1) return false; + if (gallic_weight.Size() == 0) { + *label = 0; + *weight = A::Weight::Zero(); + return true; + } + return Extract(gallic_weight.Back(), weight, label); + } + + const Label superfinal_label_; + mutable bool error_; +}; + +// Mapper from GallicArc to A. +template +class GallicToNewSymbolsMapper { + public: + using FromArc = GallicArc; + using ToArc = A; + + using Label = typename ToArc::Label; + using StateId = typename ToArc::StateId; + using AW = typename ToArc::Weight; + using GW = typename FromArc::Weight; + using SW = StringWeight; + + explicit GallicToNewSymbolsMapper(MutableFst *fst) + : fst_(fst), + lmax_(0), + osymbols_(fst->OutputSymbols()), + isymbols_(nullptr), + error_(false) { + fst_->DeleteStates(); + state_ = fst_->AddState(); + fst_->SetStart(state_); + fst_->SetFinal(state_, AW::One()); + if (osymbols_) { + string name = osymbols_->Name() + "_from_gallic"; + fst_->SetInputSymbols(new SymbolTable(name)); + isymbols_ = fst_->MutableInputSymbols(); + const int64 zero = 0; + isymbols_->AddSymbol(osymbols_->Find(zero), 0); + } else { + fst_->SetInputSymbols(nullptr); + } + } + + ToArc operator()(const FromArc &arc) { + // Super-non-final arc. + if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) { + return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId); + } + SW w1 = arc.weight.Value1(); + AW w2 = arc.weight.Value2(); + Label l; + if (w1.Size() == 0) { + l = 0; + } else { + auto insert_result = map_.insert(std::make_pair(w1, kNoLabel)); + if (!insert_result.second) { + l = insert_result.first->second; + } else { + l = ++lmax_; + insert_result.first->second = l; + StringWeightIterator iter1(w1); + StateId n; + string s; + for (size_t i = 0, p = state_; i < w1.Size(); + ++i, iter1.Next(), p = n) { + n = i == w1.Size() - 1 ? state_ : fst_->AddState(); + fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), AW::One(), n)); + if (isymbols_) { + if (i) s = s + "_"; + s = s + osymbols_->Find(iter1.Value()); + } + } + if (isymbols_) isymbols_->AddSymbol(s, l); + } + } + if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) { + FSTERROR() << "GallicToNewSymbolMapper: Unrepresentable weight: " << l; + error_ = true; + } + return ToArc(arc.ilabel, l, w2, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; + if (error_) outprops |= kError; + return outprops; + } + + private: + class StringKey { + public: + size_t operator()(const SW &x) const { return x.Hash(); } + }; + + using Map = std::unordered_map; + + MutableFst *fst_; + Map map_; + Label lmax_; + StateId state_; + const SymbolTable *osymbols_; + SymbolTable *isymbols_; + mutable bool error_; +}; + +// TODO(kbg): Add common base class for those mappers which do nothing except +// mutate their weights. + +// Mapper to add a constant to all weights. +template +class PlusMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + constexpr explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {} + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const Weight weight_; +}; + +// Mapper to (right) multiply a constant to all weights. +template +class TimesMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + constexpr explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {} + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const Weight weight_; +}; + +// Mapper to take all weights to a constant power. The power argument is stored +// as a double, so if there is a floating-point power implementation for this +// weight type, it will take precedence. Otherwise, the power argument's 53 bits +// of integer precision will be implicitly converted to a size_t and the default +// power implementation (iterated multiplication) will be used instead. +template +class PowerMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + explicit PowerMapper(double power) : power_(power) {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, Power(arc.weight, power_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const double power_; +}; + +// Mapper to reciprocate all non-Zero() weights. +template +class InvertWeightMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Divide(Weight::One(), arc.weight), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } +}; + +// Mapper to map all non-Zero() weights to One(). +template +class RmWeightMapper { + public: + using FromArc = A; + using ToArc = B; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, + arc.weight != FromWeight::Zero() ? + ToWeight::One() : ToWeight::Zero(), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kWeightInvariantProperties) | kUnweighted; + } +}; + +// Mapper to quantize all weights. +template +class QuantizeMapper { + public: + using FromArc = A; + using ToArc = B; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + QuantizeMapper() : delta_(kDelta) {} + + explicit QuantizeMapper(float d) : delta_(d) {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const float delta_; +}; + +// Mapper from A to B under the assumption: +// +// B::Weight = A::Weight::ReverseWeight +// B::Label == A::Label +// B::StateId == A::StateId +// +// The weight is reversed, while the label and nextstate are preserved. +template +class ReverseWeightMapper { + public: + using FromArc = A; + using ToArc = B; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_ARC_MAP_H_ diff --git a/projects/llm_framework/include/fst/arc.h b/projects/llm_framework/include/fst/arc.h new file mode 100644 index 00000000..651b11df --- /dev/null +++ b/projects/llm_framework/include/fst/arc.h @@ -0,0 +1,317 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Commonly used FST arc types. + +#ifndef FST_ARC_H_ +#define FST_ARC_H_ + +#include +#include +#include +#include + + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +struct ArcTpl { + public: + using Weight = W; + using Label = int; + using StateId = int; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ArcTpl() noexcept(std::is_nothrow_default_constructible::value) {} + + template + ArcTpl(Label ilabel, Label olabel, T &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string(Weight::Type() == "tropical" ? "standard" : Weight::Type()); + return *type; + } +}; + +using StdArc = ArcTpl; +using LogArc = ArcTpl; +using Log64Arc = ArcTpl; +using SignedLogArc = ArcTpl; +using SignedLog64Arc = ArcTpl; +using MinMaxArc = ArcTpl; + +// Arc with integer labels and state IDs and string weights. +template +struct StringArc { + public: + using Label = int; + using Weight = StringWeight; + using StateId = int; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + StringArc() = default; + + template + StringArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string( + S == STRING_LEFT ? "left_standard_string" + : (S == STRING_RIGHT ? "right_standard_string" + : "restricted_standard_string")); + return *type; + } +}; + +// Arc with label and state Id type the same as template arg and with +// weights over the Gallic semiring w.r.t the output labels and weights of A. +template +struct GallicArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = GallicWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + GallicArc() = default; + + template + GallicArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + explicit GallicArc(const Arc &arc) + : ilabel(arc.ilabel), olabel(arc.ilabel), weight(arc.olabel, arc.weight), + nextstate(arc.nextstate) {} + + static const string &Type() { + static const auto *const type = new string( + (G == GALLIC_LEFT + ? "left_gallic_" + : (G == GALLIC_RIGHT + ? "right_gallic_" + : (G == GALLIC_RESTRICT + ? "restricted_gallic_" + : (G == GALLIC_MIN ? "min_gallic_" : "gallic_")))) + + Arc::Type()); + return *type; + } +}; + +// Arc with the reverse of the weight found in its template arg. +template +struct ReverseArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using AWeight = typename Arc::Weight; + using Weight = typename AWeight::ReverseWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ReverseArc() = default; + + template + ReverseArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string("reverse_" + Arc::Type()); + return *type; + } +}; + +// Arc with integer labels and state IDs and lexicographic weights. +template +struct LexicographicArc { + using Label = int; + using StateId = int; + using Weight = LexicographicWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + LexicographicArc() = default; + + template + LexicographicArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const string *const type = new string(Weight::Type()); + return *type; + } +}; + +// Arc with integer labels and state IDs and product weights. +template +struct ProductArc { + using Label = int; + using StateId = int; + using Weight = ProductWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ProductArc() = default; + + template + ProductArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string(Weight::Type()); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// weights over the n-th Cartesian power of the weight type of the template +// argument. +template +struct PowerArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = PowerWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + PowerArc() = default; + + template + PowerArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string(Arc::Type() + "_^" + std::to_string(n)); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// weights over the arbitrary Cartesian power of the weight type. +template +struct SparsePowerArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::Label; + using Weight = SparsePowerWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + SparsePowerArc() = default; + + template + SparsePowerArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const string *const type = [] { + string type = Arc::Type() + "_^n"; + if (sizeof(K) != sizeof(uint32)) { + type += "_" + std::to_string(CHAR_BIT * sizeof(K)); + } + return new string(type); + }(); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// expectation weight over the first template argument's weight type and the +// second template argument. +template +struct ExpectationArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using X1 = typename Arc::Weight; + using Weight = ExpectationWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ExpectationArc() = default; + + template + ExpectationArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string("expectation_" + Arc::Type() + "_" + X2::Type()); + return *type; + } +}; + +} // namespace fst + +#endif // FST_ARC_H_ diff --git a/projects/llm_framework/include/fst/arcfilter.h b/projects/llm_framework/include/fst/arcfilter.h new file mode 100644 index 00000000..598e543f --- /dev/null +++ b/projects/llm_framework/include/fst/arcfilter.h @@ -0,0 +1,93 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function objects to restrict which arcs are traversed in an FST. + +#ifndef FST_ARCFILTER_H_ +#define FST_ARCFILTER_H_ + + +#include +#include + + +namespace fst { + +// True for all arcs. +template +class AnyArcFilter { + public: + bool operator()(const Arc &arc) const { return true; } +}; + +// True for (input/output) epsilon arcs. +template +class EpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { + return arc.ilabel == 0 && arc.olabel == 0; + } +}; + +// True for input epsilon arcs. +template +class InputEpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { return arc.ilabel == 0; } +}; + +// True for output epsilon arcs. +template +class OutputEpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { return arc.olabel == 0; } +}; + +// True if specified label matches (doesn't match) when keep_match is +// true (false). +template +class LabelArcFilter { + public: + using Label = typename Arc::Label; + + explicit LabelArcFilter(Label label, bool match_input = true, + bool keep_match = true) + : label_(label), match_input_(match_input), keep_match_(keep_match) {} + + bool operator()(const Arc &arc) const { + const bool match = (match_input_ ? arc.ilabel : arc.olabel) == label_; + return keep_match_ ? match : !match; + } + + private: + const Label label_; + const bool match_input_; + const bool keep_match_; +}; + +// True if specified labels match (don't match) when keep_match is true (false). +template +class MultiLabelArcFilter { + public: + using Label = typename Arc::Label; + + explicit MultiLabelArcFilter(bool match_input = true, bool keep_match = true) + : match_input_(match_input), keep_match_(keep_match) {} + + bool operator()(const Arc &arc) const { + const Label label = match_input_ ? arc.ilabel : arc.olabel; + const bool match = labels_.Find(label) != labels_.End(); + return keep_match_ ? match : !match; + } + + void AddLabel(Label label) { labels_.Insert(label); } + + private: + CompactSet labels_; + const bool match_input_; + const bool keep_match_; +}; + +} // namespace fst + +#endif // FST_ARCFILTER_H_ diff --git a/projects/llm_framework/include/fst/arcsort.h b/projects/llm_framework/include/fst/arcsort.h new file mode 100644 index 00000000..b5ab50e0 --- /dev/null +++ b/projects/llm_framework/include/fst/arcsort.h @@ -0,0 +1,211 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to sort arcs in an FST. + +#ifndef FST_ARCSORT_H_ +#define FST_ARCSORT_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +template +class ArcSortMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + constexpr ArcSortMapper(const Fst &fst, const Compare &comp) + : fst_(fst), comp_(comp), i_(0) {} + + // Allows updating Fst argument; pass only if changed. + ArcSortMapper(const ArcSortMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_), comp_(mapper.comp_), i_(0) {} + + StateId Start() { return fst_.Start(); } + + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(s)); + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + std::sort(arcs_.begin(), arcs_.end(), comp_); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + uint64 Properties(uint64 props) const { return comp_.Properties(props); } + + private: + const Fst &fst_; + const Compare &comp_; + std::vector arcs_; + ssize_t i_; // current arc position + + ArcSortMapper &operator=(const ArcSortMapper &) = delete; +}; + +// Sorts the arcs in an FST according to function object 'comp' of type Compare. +// This version modifies its input. Comparison function objects ILabelCompare +// and OLabelCompare are provided by the library. In general, Compare must meet +// the requirements for a comparison function object (e.g., similar to those +// used by std::sort). It must also have a member Properties(uint64) that +// specifies the known properties of the sorted FST; it takes as argument the +// input FST's known properties before the sort. +// +// Complexity: +// +// - Time: O(v d log d) +// - Space: O(d) +// +// where v = # of states and d = maximum out-degree. +template +void ArcSort(MutableFst *fst, Compare comp) { + ArcSortMapper mapper(*fst, comp); + StateMap(fst, mapper); +} + +using ArcSortFstOptions = CacheOptions; + +// Sorts the arcs in an FST according to function object 'comp' of type Compare. +// This version is a delayed FST. Comparsion function objects ILabelCompare and +// OLabelCompare are provided by the library. In general, Compare must meet the +// requirements for a comparision function object (e.g., similar to those +// used by std::sort). It must also have a member Properties(uint64) that +// specifies the known properties of the sorted FST; it takes as argument the +// input FST's known properties. +// +// Complexity: +// +// - Time: O(v d log d) +// - Space: O(d) +// +// where v = # of states visited, d = maximum out-degree of states visited. +// Constant time and space to visit an input state is assumed and exclusive of +// caching. +template +class ArcSortFst : public StateMapFst> { + using StateMapFst>::GetImpl; + + public: + using StateId = typename Arc::StateId; + using Mapper = ArcSortMapper; + + ArcSortFst(const Fst &fst, const Compare &comp) + : StateMapFst(fst, + ArcSortMapper(fst, comp)) {} + + ArcSortFst(const Fst &fst, const Compare &comp, + const ArcSortFstOptions &opts) + : StateMapFst(fst, Mapper(fst, comp), opts) {} + + // See Fst<>::Copy() for doc. + ArcSortFst(const ArcSortFst &fst, bool safe = false) + : StateMapFst(fst, safe) {} + + // Gets a copy of this ArcSortFst. See Fst<>::Copy() for further doc. + ArcSortFst *Copy(bool safe = false) const override { + return new ArcSortFst(*this, safe); + } + + size_t NumArcs(StateId s) const override { + return GetImpl()->GetFst()->NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) const override { + return GetImpl()->GetFst()->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) const override { + return GetImpl()->GetFst()->NumOutputEpsilons(s); + } +}; + +// Specialization for ArcSortFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const ArcSortFst &fst) + : StateIterator>>(fst) { + } +}; + +// Specialization for ArcSortFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const ArcSortFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, + s) {} +}; + +// Compare class for comparing input labels of arcs. +template +class ILabelCompare { + public: + constexpr ILabelCompare() {} + + constexpr bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.ilabel < arc2.ilabel; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kILabelSorted | + (props & kAcceptor ? kOLabelSorted : 0); + } +}; + +// Compare class for comparing output labels of arcs. +template +class OLabelCompare { + public: + constexpr OLabelCompare() {} + + constexpr bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.olabel < arc2.olabel; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kOLabelSorted | + (props & kAcceptor ? kILabelSorted : 0); + } +}; + +// Useful aliases when using StdArc. + +template +using StdArcSortFst = ArcSortFst; + +using StdILabelCompare = ILabelCompare; + +using StdOLabelCompare = OLabelCompare; + +} // namespace fst + +#endif // FST_ARCSORT_H_ diff --git a/projects/llm_framework/include/fst/bi-table.h b/projects/llm_framework/include/fst/bi-table.h new file mode 100644 index 00000000..9651cfe8 --- /dev/null +++ b/projects/llm_framework/include/fst/bi-table.h @@ -0,0 +1,480 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for representing a bijective mapping between an arbitrary entry +// of type T and a signed integral ID. + +#ifndef FST_BI_TABLE_H_ +#define FST_BI_TABLE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Bitables model bijective mappings between entries of an arbitrary type T and +// an signed integral ID of type I. The IDs are allocated starting from 0 in +// order. +// +// template +// class BiTable { +// public: +// +// // Required constructors. +// BiTable(); +// +// // Looks up integer ID from entry. If it doesn't exist and insert +// / is true, adds it; otherwise, returns -1. +// I FindId(const T &entry, bool insert = true); +// +// // Looks up entry from integer ID. +// const T &FindEntry(I) const; +// +// // Returns number of stored entries. +// I Size() const; +// }; + +// An implementation using a hash map for the entry to ID mapping. H is the +// hash function and E is the equality function. If passed to the constructor, +// ownership is given to this class. +template > +class HashBiTable { + public: + // Reserves space for table_size elements. If passing H and E to the + // constructor, this class owns them. + explicit HashBiTable(size_t table_size = 0, H *h = nullptr, E *e = nullptr) : + hash_func_(h ? h : new H()), hash_equal_(e ? e : new E()), + entry2id_(table_size, *hash_func_, *hash_equal_) { + if (table_size) id2entry_.reserve(table_size); + } + + HashBiTable(const HashBiTable &table) + : hash_func_(new H(*table.hash_func_)), + hash_equal_(new E(*table.hash_equal_)), + entry2id_(table.entry2id_.begin(), table.entry2id_.end(), + table.entry2id_.size(), *hash_func_, *hash_equal_), + id2entry_(table.id2entry_) {} + + I FindId(const T &entry, bool insert = true) { + if (!insert) { + const auto it = entry2id_.find(entry); + return it == entry2id_.end() ? -1 : it->second - 1; + } + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found; stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable. + void Clear() { + entry2id_.clear(); + id2entry_.clear(); + } + + private: + std::unique_ptr hash_func_; + std::unique_ptr hash_equal_; + std::unordered_map entry2id_; + std::vector id2entry_; +}; + +// Enables alternative hash set representations below. +enum HSType { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2, HS_FLAT = 3 }; + +// Default hash set is STL hash_set. +template +struct HashSet : public std::unordered_set> { + explicit HashSet(size_t n = 0, const H &h = H(), const E &e = E()) + : std::unordered_set>(n, h, e) {} + + void rehash(size_t n) {} +}; + +// An implementation using a hash set for the entry to ID mapping. The hash set +// holds keys which are either the ID or kCurrentKey. These keys can be mapped +// to entries either by looking up in the entry vector or, if kCurrentKey, in +// current_entry_. The hash and key equality functions map to entries first. H +// is the hash function and E is the equality function. If passed to the +// constructor, ownership is given to this class. +// TODO(rybach): remove support for (deprecated and unused) HS_DENSE, HS_SPARSE. +template , + HSType HS = HS_FLAT> +class CompactHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + // Reserves space for table_size elements. If passing H and E to the + // constructor, this class owns them. + explicit CompactHashBiTable(size_t table_size = 0, H *h = nullptr, + E *e = nullptr) : + hash_func_(h ? h : new H()), hash_equal_(e ? e : new E()), + compact_hash_func_(*this), compact_hash_equal_(*this), + keys_(table_size, compact_hash_func_, compact_hash_equal_) { + if (table_size) id2entry_.reserve(table_size); + } + + CompactHashBiTable(const CompactHashBiTable &table) + : hash_func_(new H(*table.hash_func_)), + hash_equal_(new E(*table.hash_equal_)), + compact_hash_func_(*this), compact_hash_equal_(*this), + keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_), + id2entry_(table.id2entry_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + I FindId(const T &entry, bool insert = true) { + current_entry_ = &entry; + if (insert) { + auto result = keys_.insert(kCurrentKey); + if (!result.second) return *result.first; // Already exists. + // Overwrites kCurrentKey with a new key value; this is safe because it + // doesn't affect hashing or equality testing. + I key = id2entry_.size(); + const_cast(*result.first) = key; + id2entry_.push_back(entry); + return key; + } + const auto it = keys_.find(kCurrentKey); + return it == keys_.end() ? -1 : *it; + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + // Clears content; with argument, erases last n IDs. + void Clear(ssize_t n = -1) { + if (n < 0 || n >= id2entry_.size()) { // Clears completely. + keys_.clear(); + id2entry_.clear(); + } else if (n == id2entry_.size() - 1) { // Leaves only key 0. + const T entry = FindEntry(0); + keys_.clear(); + id2entry_.clear(); + FindId(entry, true); + } else { + while (n-- > 0) { + I key = id2entry_.size() - 1; + keys_.erase(key); + id2entry_.pop_back(); + } + keys_.rehash(0); + } + } + + private: + static_assert(std::is_signed::value, "I must be a signed type"); + // ... otherwise >= kCurrentKey comparisons as used below don't work. + // TODO(rybach): (1) remove kEmptyKey, kDeletedKey, (2) don't use >= for key + // comparison, (3) allow unsigned key types. + static constexpr I kCurrentKey = -1; + static constexpr I kEmptyKey = -2; + static constexpr I kDeletedKey = -3; + + class HashFunc { + public: + explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*ht_->hash_func_)(ht_->Key2Entry(k)); + } else { + return 0; + } + } + + private: + const CompactHashBiTable *ht_; + }; + + class HashEqual { + public: + explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 == k2) { + return true; + } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2)); + } else { + return false; + } + } + + private: + const CompactHashBiTable *ht_; + }; + + using KeyHashSet = HashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) { + return *current_entry_; + } else { + return id2entry_[k]; + } + } + + std::unique_ptr hash_func_; + std::unique_ptr hash_equal_; + HashFunc compact_hash_func_; + HashEqual compact_hash_equal_; + KeyHashSet keys_; + std::vector id2entry_; + const T *current_entry_; +}; + +template +constexpr I CompactHashBiTable::kCurrentKey; + +template +constexpr I CompactHashBiTable::kEmptyKey; + +template +constexpr I CompactHashBiTable::kDeletedKey; + +// An implementation using a vector for the entry to ID mapping. It is passed a +// function object FP that should fingerprint entries uniquely to an integer +// that can used as a vector index. Normally, VectorBiTable constructs the FP +// object. The user can instead pass in this object; in that case, VectorBiTable +// takes its ownership. +template +class VectorBiTable { + public: + // Reserves table_size cells of space. If passing FP argument to the + // constructor, this class owns it. + explicit VectorBiTable(FP *fp = nullptr, size_t table_size = 0) : + fp_(fp ? fp : new FP()) { + if (table_size) id2entry_.reserve(table_size); + } + + VectorBiTable(const VectorBiTable &table) + : fp_(new FP(*table.fp_)), fp2id_(table.fp2id_), + id2entry_(table.id2entry_) {} + + I FindId(const T &entry, bool insert = true) { + ssize_t fp = (*fp_)(entry); + if (fp >= fp2id_.size()) fp2id_.resize(fp + 1); + I &id_ref = fp2id_[fp]; + if (id_ref == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + const FP &Fingerprint() const { return *fp_; } + + private: + std::unique_ptr fp_; + std::vector fp2id_; + std::vector id2entry_; +}; + +// An implementation using a vector and a compact hash table. The selecting +// functor S returns true for entries to be hashed in the vector. The +// fingerprinting functor FP returns a unique fingerprint for each entry to be +// hashed in the vector (these need to be suitable for indexing in a vector). +// The hash functor H is used when hashing entry into the compact hash table. +// If passed to the constructor, ownership is given to this class. +template +class VectorHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + explicit VectorHashBiTable(S *s, FP *fp, H *h, size_t vector_size = 0, + size_t entry_size = 0) + : selector_(s), fp_(fp), h_(h), hash_func_(*this), hash_equal_(*this), + keys_(0, hash_func_, hash_equal_) { + if (vector_size) fp2id_.reserve(vector_size); + if (entry_size) id2entry_.reserve(entry_size); + } + + VectorHashBiTable(const VectorHashBiTable &table) + : selector_(new S(table.s_)), fp_(new FP(*table.fp_)), + h_(new H(*table.h_)), id2entry_(table.id2entry_), + fp2id_(table.fp2id_), hash_func_(*this), hash_equal_(*this), + keys_(table.keys_.size(), hash_func_, hash_equal_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + I FindId(const T &entry, bool insert = true) { + if ((*selector_)(entry)) { // Uses the vector if selector_(entry) == true. + uint64 fp = (*fp_)(entry); + if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0); + if (fp2id_[fp] == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + fp2id_[fp] = id2entry_.size(); + } else { + return -1; + } + } + return fp2id_[fp] - 1; // NB: assoc_value = ID + 1. + } else { // Uses the hash table otherwise. + current_entry_ = &entry; + const auto it = keys_.find(kCurrentKey); + if (it == keys_.end()) { + if (insert) { + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } + } else { + return *it; + } + } + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + const S &Selector() const { return *selector_; } + + const FP &Fingerprint() const { return *fp_; } + + const H &Hash() const { return *h_; } + + private: + static constexpr I kCurrentKey = -1; + static constexpr I kEmptyKey = -2; + + class HashFunc { + public: + explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*(ht_->h_))(ht_->Key2Entry(k)); + } else { + return 0; + } + } + + private: + const VectorHashBiTable *ht_; + }; + + class HashEqual { + public: + explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); + } else { + return k1 == k2; + } + } + + private: + const VectorHashBiTable *ht_; + }; + + using KeyHashSet = HashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) { + return *current_entry_; + } else { + return id2entry_[k]; + } + } + + std::unique_ptr selector_; // True if entry hashed into vector. + std::unique_ptr fp_; // Fingerprint used for hashing into vector. + std::unique_ptr h_; // Hash funcion used for hashing into hash_set. + + std::vector id2entry_; // Maps state IDs to entry. + std::vector fp2id_; // Maps entry fingerprints to IDs. + + // Compact implementation of the hash table mapping entries to state IDs + // using the hash function h_. + HashFunc hash_func_; + HashEqual hash_equal_; + KeyHashSet keys_; + const T *current_entry_; +}; + +template +constexpr I VectorHashBiTable::kCurrentKey; + +template +constexpr I VectorHashBiTable::kEmptyKey; + +// An implementation using a hash map for the entry to ID mapping. This version +// permits erasing of arbitrary states. The entry T must have == defined and +// its default constructor must produce a entry that will never be seen. F is +// the hash function. +template +class ErasableBiTable { + public: + ErasableBiTable() : first_(0) {} + + I FindId(const T &entry, bool insert = true) { + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size() + first_; + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s - first_]; } + + I Size() const { return id2entry_.size(); } + + void Erase(I s) { + auto &ref = id2entry_[s - first_]; + entry2id_.erase(ref); + ref = empty_entry_; + while (!id2entry_.empty() && id2entry_.front() == empty_entry_) { + id2entry_.pop_front(); + ++first_; + } + } + + private: + std::unordered_map entry2id_; + std::deque id2entry_; + const T empty_entry_; + I first_; // I of first element in the deque. +}; + +} // namespace fst + +#endif // FST_BI_TABLE_H_ diff --git a/projects/llm_framework/include/fst/cache.h b/projects/llm_framework/include/fst/cache.h new file mode 100644 index 00000000..13b7cf81 --- /dev/null +++ b/projects/llm_framework/include/fst/cache.h @@ -0,0 +1,1327 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation that caches FST elements of a delayed computation. + +#ifndef FST_CACHE_H_ +#define FST_CACHE_H_ + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +DECLARE_bool(fst_default_cache_gc); +DECLARE_int64(fst_default_cache_gc_limit); + +namespace fst { + +// Options for controlling caching behavior; higher level than CacheImplOptions. +struct CacheOptions { + bool gc; // Enables GC. + size_t gc_limit; // Number of bytes allowed before GC. + + explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc, + size_t gc_limit = FLAGS_fst_default_cache_gc_limit) + : gc(gc), gc_limit(gc_limit) {} +}; + +// Options for controlling caching behavior, at a lower level than +// CacheOptions; templated on the cache store and allows passing the store. +template +struct CacheImplOptions { + bool gc; // Enables GC. + size_t gc_limit; // Number of bytes allowed before GC. + CacheStore *store; // Cache store. + bool own_store; // Should CacheImpl takes ownership of the store? + + explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc, + size_t gc_limit = FLAGS_fst_default_cache_gc_limit, + CacheStore *store = nullptr) + : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {} + + explicit CacheImplOptions(const CacheOptions &opts) + : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {} +}; + +// Cache flags. +constexpr uint32 kCacheFinal = 0x0001; // Final weight has been cached. +constexpr uint32 kCacheArcs = 0x0002; // Arcs have been cached. +constexpr uint32 kCacheInit = 0x0004; // Initialized by GC. +constexpr uint32 kCacheRecent = 0x0008; // Visited since GC. +constexpr uint32 kCacheFlags = + kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent; + +// Cache state, with arcs stored in a per-state std::vector. +template > +class CacheState { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ArcAllocator = M; + using StateAllocator = + typename ArcAllocator::template rebind>::other; + + // Provides STL allocator for arcs. + explicit CacheState(const ArcAllocator &alloc) + : final_(Weight::Zero()), + niepsilons_(0), + noepsilons_(0), + arcs_(alloc), + flags_(0), + ref_count_(0) {} + + CacheState(const CacheState &state, const ArcAllocator &alloc) + : final_(state.Final()), + niepsilons_(state.NumInputEpsilons()), + noepsilons_(state.NumOutputEpsilons()), + arcs_(state.arcs_.begin(), state.arcs_.end(), alloc), + flags_(state.Flags()), + ref_count_(0) {} + + void Reset() { + final_ = Weight::Zero(); + niepsilons_ = 0; + noepsilons_ = 0; + ref_count_ = 0; + flags_ = 0; + arcs_.clear(); + } + + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return arcs_.size(); } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + // Used by the ArcIterator> efficient implementation. + const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + // Accesses flags; used by the caller. + uint32 Flags() const { return flags_; } + + // Accesses ref count; used by the caller. + int RefCount() const { return ref_count_; } + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void ReserveArcs(size_t n) { arcs_.reserve(n); } + + // Adds one arc at a time with all needed book-keeping; use PushArc and + // SetArcs for a more efficient alternative. + void AddArc(const Arc &arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(arc); + } + + void AddArc(Arc &&arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(std::move(arc)); + } + + // Adds one arc at a time with delayed book-keeping; finalize with SetArcs(). + void PushArc(const Arc &arc) { arcs_.push_back(arc); } + + void PushArc(Arc &&arc) { arcs_.push_back(std::move(arc)); } + + // Adds one arc at a time with delayed book-keeping; finalize with SetArcs(). + template + void EmplaceArc(T &&... ctor_args) { + arcs_.emplace_back(std::forward(ctor_args)...); + } + + // Finalizes arcs book-keeping; call only once. + void SetArcs() { + for (const auto &arc : arcs_) { + IncrementNumEpsilons(arc); + } + } + + // Modifies nth arc. + void SetArc(const Arc &arc, size_t n) { + if (arcs_[n].ilabel == 0) --niepsilons_; + if (arcs_[n].olabel == 0) --noepsilons_; + IncrementNumEpsilons(arc); + arcs_[n] = arc; + } + + // Deletes all arcs. + void DeleteArcs() { + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + void DeleteArcs(size_t n) { + for (size_t i = 0; i < n; ++i) { + if (arcs_.back().ilabel == 0) --niepsilons_; + if (arcs_.back().olabel == 0) --noepsilons_; + arcs_.pop_back(); + } + } + + // Sets status flags; used by the caller. + void SetFlags(uint32 flags, uint32 mask) const { + flags_ &= ~mask; + flags_ |= flags; + } + + // Mutates reference counts; used by the caller. + + int IncrRefCount() const { return ++ref_count_; } + + int DecrRefCount() const { return --ref_count_; } + + // Used by the ArcIterator> efficient implementation. + int *MutableRefCount() const { return &ref_count_; } + + // Used for state class allocation. + void *operator new(size_t size, StateAllocator *alloc) { + return alloc->allocate(1); + } + + // For state destruction and memory freeing. + static void Destroy(CacheState *state, StateAllocator *alloc) { + if (state) { + state->~CacheState(); + alloc->deallocate(state, 1); + } + } + + private: + // Update the number of epsilons as a result of having added an arc. + void IncrementNumEpsilons(const Arc &arc) { + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + } + + Weight final_; // Final weight. + size_t niepsilons_; // # of input epsilons. + size_t noepsilons_; // # of output epsilons. + std::vector arcs_; // Arcs representation. + mutable uint32 flags_; + mutable int ref_count_; // If 0, available for GC. +}; + +// Cache store, allocating and storing states, providing a mapping from state +// IDs to cached states, and an iterator over these states. The state template +// argument must implement the CacheState interface. The state for a StateId s +// is constructed when requested by GetMutableState(s) if it is not yet stored. +// Initially, a state has a reference count of zero, but the user may increment +// or decrement this to control the time of destruction. In particular, a state +// is destroyed when: +// +// 1. This instance is destroyed, or +// 2. Clear() or Delete() is called, or +// 3. Possibly (implementation-dependently) when: +// - Garbage collection is enabled (as defined by opts.gc), +// - The cache store size exceeds the limits (as defined by opts.gc_limits), +// - The state's reference count is zero, and +// - The state is not the most recently requested state. +// +// template +// class CacheStore { +// public: +// using State = S; +// using Arc = typename State::Arc; +// using StateId = typename Arc::StateId; +// +// // Required constructors/assignment operators. +// explicit CacheStore(const CacheOptions &opts); +// +// // Returns nullptr if state is not stored. +// const State *GetState(StateId s); +// +// // Creates state if state is not stored. +// State *GetMutableState(StateId s); +// +// // Similar to State::AddArc() but updates cache store book-keeping. +// void AddArc(State *state, const Arc &arc); +// +// // Similar to State::SetArcs() but updates cache store book-keeping; call +// // only once. +// void SetArcs(State *state); +// +// // Similar to State::DeleteArcs() but updates cache store book-keeping. +// +// void DeleteArcs(State *state); +// +// void DeleteArcs(State *state, size_t n); +// +// // Deletes all cached states. +// void Clear(); +// +// // Number of cached states. +// StateId CountStates(); +// +// // Iterates over cached states (in an arbitrary order); only needed if +// // opts.gc is true. +// bool Done() const; // End of iteration. +// StateId Value() const; // Current state. +// void Next(); // Advances to next state (when !Done). +// void Reset(); // Returns to initial condition. +// void Delete(); // Deletes current state and advances to next. +// }; + +// Container cache stores. + +// This class uses a vector of pointers to states to store cached states. +template +class VectorCacheStore { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using StateList = std::list>; + + // Required constructors/assignment operators. + explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) { + Clear(); + Reset(); + } + + VectorCacheStore(const VectorCacheStore &store) + : cache_gc_(store.cache_gc_) { + CopyStates(store); + Reset(); + } + + ~VectorCacheStore() { Clear(); } + + VectorCacheStore &operator=(const VectorCacheStore &store) { + if (this != &store) { + CopyStates(store); + Reset(); + } + return *this; + } + + bool InBounds(StateId s) const { + return s < static_cast(state_vec_.size()); + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + return InBounds(s) ? state_vec_[s] : nullptr; + } + + // Creates state if state is not stored. + State *GetMutableState(StateId s) { + State *state = nullptr; + if (InBounds(s)) { + state = state_vec_[s]; + } else { + state_vec_.resize(s + 1, nullptr); + } + if (!state) { + state = new (&state_alloc_) State(arc_alloc_); + state_vec_[s] = state; + if (cache_gc_) state_list_.push_back(s); + } + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping + void AddArc(State *state, const Arc &arc) { state->AddArc(arc); } + + // Similar to State::SetArcs() but updates cache store book-keeping; call + // only once. + void SetArcs(State *state) { state->SetArcs(); } + + // Deletes all arcs. + void DeleteArcs(State *state) { state->DeleteArcs(); } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); } + + // Deletes all cached states. + void Clear() { + for (State *s : state_vec_) { + State::Destroy(s, &state_alloc_); + } + state_vec_.clear(); + state_list_.clear(); + } + + StateId CountStates() const { + return std::count_if(state_vec_.begin(), state_vec_.end(), + [](const State *s) { return s != nullptr; }); + } + + // Iterates over cached states (in an arbitrary order); only works if GC is + // enabled (o.w. avoiding state_list_ overhead). + bool Done() const { return iter_ == state_list_.end(); } + + StateId Value() const { return *iter_; } + + void Next() { ++iter_; } + + void Reset() { iter_ = state_list_.begin(); } + + // Deletes current state and advances to next. + void Delete() { + State::Destroy(state_vec_[*iter_], &state_alloc_); + state_vec_[*iter_] = nullptr; + state_list_.erase(iter_++); + } + + private: + void CopyStates(const VectorCacheStore &store) { + Clear(); + state_vec_.reserve(store.state_vec_.size()); + for (size_t s = 0; s < store.state_vec_.size(); ++s) { + State *state = nullptr; + const auto *store_state = store.state_vec_[s]; + if (store_state) { + state = new (&state_alloc_) State(*store_state, arc_alloc_); + if (cache_gc_) state_list_.push_back(s); + } + state_vec_.push_back(state); + } + } + + bool cache_gc_; // Supports iteration when true. + std::vector state_vec_; // Vector of states (or null). + StateList state_list_; // List of states. + typename StateList::iterator iter_; // State list iterator. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// This class uses a hash map from state IDs to pointers to cached states. +template +class HashCacheStore { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + using StateMap = + std::unordered_map, + std::equal_to, + PoolAllocator>>; + + // Required constructors/assignment operators. + explicit HashCacheStore(const CacheOptions &opts) { + Clear(); + Reset(); + } + + HashCacheStore(const HashCacheStore &store) { + CopyStates(store); + Reset(); + } + + ~HashCacheStore() { Clear(); } + + HashCacheStore &operator=(const HashCacheStore &store) { + if (this != &store) { + CopyStates(store); + Reset(); + } + return *this; + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + const auto it = state_map_.find(s); + return it != state_map_.end() ? it->second : nullptr; + } + + // Creates state if state is not stored. + State *GetMutableState(StateId s) { + auto *&state = state_map_[s]; + if (!state) state = new (&state_alloc_) State(arc_alloc_); + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. + void AddArc(State *state, const Arc &arc) { state->AddArc(arc); } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { state->SetArcs(); } + + // Deletes all arcs. + void DeleteArcs(State *state) { state->DeleteArcs(); } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); } + + // Deletes all cached states. + void Clear() { + for (auto it = state_map_.begin(); it != state_map_.end(); ++it) { + State::Destroy(it->second, &state_alloc_); + } + state_map_.clear(); + } + + StateId CountStates() const { return state_map_.size(); } + + // Iterates over cached states (in an arbitrary order). + bool Done() const { return iter_ == state_map_.end(); } + + StateId Value() const { return iter_->first; } + + void Next() { ++iter_; } + + void Reset() { iter_ = state_map_.begin(); } + + // Deletes current state and advances to next. + void Delete() { + State::Destroy(iter_->second, &state_alloc_); + state_map_.erase(iter_++); + } + + private: + void CopyStates(const HashCacheStore &store) { + Clear(); + for (auto it = store.state_map_.begin(); it != store.state_map_.end(); + ++it) { + state_map_[it->first] = + new (&state_alloc_) State(*it->second, arc_alloc_); + } + } + + StateMap state_map_; // Map from state ID to state. + typename StateMap::iterator iter_; // State map iterator. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// Garbage-colllection cache stores. + +// This class implements a simple garbage collection scheme when +// 'opts.gc_limit = 0'. In particular, the first cached state is reused for each +// new state so long as the reference count is zero on the to-be-reused state. +// Otherwise, the full underlying store is used. The caller can increment the +// reference count to inhibit the GC of in-use states (e.g., in an ArcIterator). +// +// The typical use case for this optimization is when a single pass over a +// cached +// FST is performed with only one-state expanded at a time. +template +class FirstCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + // Required constructors/assignment operators. + explicit FirstCacheStore(const CacheOptions &opts) + : store_(opts), + cache_gc_(opts.gc_limit == 0), // opts.gc ignored historically. + cache_first_state_id_(kNoStateId), + cache_first_state_(nullptr) {} + + FirstCacheStore(const FirstCacheStore &store) + : store_(store.store_), + cache_gc_(store.cache_gc_), + cache_first_state_id_(store.cache_first_state_id_), + cache_first_state_(store.cache_first_state_id_ != kNoStateId + ? store_.GetMutableState(0) + : nullptr) {} + + FirstCacheStore &operator=( + const FirstCacheStore &store) { + if (this != &store) { + store_ = store.store_; + cache_gc_ = store.cache_gc_; + cache_first_state_id_ = store.cache_first_state_id_; + cache_first_state_ = store.cache_first_state_id_ != kNoStateId + ? store_.GetMutableState(0) + : nullptr; + } + return *this; + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + // store_ state 0 may hold first cached state; the rest are shifted by 1. + return s == cache_first_state_id_ ? cache_first_state_ + : store_.GetState(s + 1); + } + + // Creates state if state is not stored. + State *GetMutableState(StateId s) { + // store_ state 0 used to hold first cached state; the rest are shifted by + // 1. + if (cache_first_state_id_ == s) { + return cache_first_state_; // Request for first cached state. + } + if (cache_gc_) { + if (cache_first_state_id_ == kNoStateId) { + cache_first_state_id_ = s; // Sets first cached state. + cache_first_state_ = store_.GetMutableState(0); + cache_first_state_->SetFlags(kCacheInit, kCacheInit); + cache_first_state_->ReserveArcs(2 * kAllocSize); + return cache_first_state_; + } else if (cache_first_state_->RefCount() == 0) { + cache_first_state_id_ = s; // Updates first cached state. + cache_first_state_->Reset(); + cache_first_state_->SetFlags(kCacheInit, kCacheInit); + return cache_first_state_; + } else { // Keeps first cached state. + cache_first_state_->SetFlags(0, kCacheInit); // Clears initialized bit. + cache_gc_ = false; // Disables GC. + } + } + auto *state = store_.GetMutableState(s + 1); + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. + void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { store_.SetArcs(state); } + + // Deletes all arcs + void DeleteArcs(State *state) { store_.DeleteArcs(state); } + + // Deletes some arcs + void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); } + + // Deletes all cached states + void Clear() { + store_.Clear(); + cache_first_state_id_ = kNoStateId; + cache_first_state_ = nullptr; + } + + StateId CountStates() const { return store_.CountStates(); } + + // Iterates over cached states (in an arbitrary order). Only needed if GC is + // enabled. + bool Done() const { return store_.Done(); } + + StateId Value() const { + // store_ state 0 may hold first cached state; rest shifted + 1. + const auto s = store_.Value(); + return s ? s - 1 : cache_first_state_id_; + } + + void Next() { store_.Next(); } + + void Reset() { store_.Reset(); } + + // Deletes current state and advances to next. + void Delete() { + if (Value() == cache_first_state_id_) { + cache_first_state_id_ = kNoStateId; + cache_first_state_ = nullptr; + } + store_.Delete(); + } + + private: + CacheStore store_; // Underlying store. + bool cache_gc_; // GC enabled. + StateId cache_first_state_id_; // First cached state ID. + State *cache_first_state_; // First cached state. +}; + +// This class implements mark-sweep garbage collection on an underlying cache +// store. If GC is enabled, garbage collection of states is performed in a +// rough approximation of LRU order once when 'gc_limit' bytes is reached. The +// caller can increment the reference count to inhibit the GC of in-use state +// (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows +// the caller to trade-off time vs. space. +template +class GCCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + // Required constructors/assignment operators. + explicit GCCacheStore(const CacheOptions &opts) + : store_(opts), + cache_gc_request_(opts.gc), + cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit + : kMinCacheLimit), + cache_gc_(false), + cache_size_(0) {} + + // Returns 0 if state is not stored. + const State *GetState(StateId s) const { return store_.GetState(s); } + + // Creates state if state is not stored + State *GetMutableState(StateId s) { + auto *state = store_.GetMutableState(s); + if (cache_gc_request_ && !(state->Flags() & kCacheInit)) { + state->SetFlags(kCacheInit, kCacheInit); + cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc); + // GC is enabled once an uninited state (from underlying store) is seen. + cache_gc_ = true; + if (cache_size_ > cache_limit_) GC(state, false); + } + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. + void AddArc(State *state, const Arc &arc) { + store_.AddArc(state, arc); + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ += sizeof(Arc); + if (cache_size_ > cache_limit_) GC(state, false); + } + } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { + store_.SetArcs(state); + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ += state->NumArcs() * sizeof(Arc); + if (cache_size_ > cache_limit_) GC(state, false); + } + } + + // Deletes all arcs. + void DeleteArcs(State *state) { + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ -= state->NumArcs() * sizeof(Arc); + } + store_.DeleteArcs(state); + } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ -= n * sizeof(Arc); + } + store_.DeleteArcs(state, n); + } + + // Deletes all cached states. + void Clear() { + store_.Clear(); + cache_size_ = 0; + } + + StateId CountStates() const { return store_.CountStates(); } + + // Iterates over cached states (in an arbitrary order); only needed if GC is + // enabled. + bool Done() const { return store_.Done(); } + + StateId Value() const { return store_.Value(); } + + void Next() { store_.Next(); } + + void Reset() { store_.Reset(); } + + // Deletes current state and advances to next. + void Delete() { + if (cache_gc_) { + const auto *state = store_.GetState(Value()); + if (state->Flags() & kCacheInit) { + cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc); + } + } + store_.Delete(); + } + + // Removes from the cache store (not referenced-counted and not the current) + // states that have not been accessed since the last GC until at most + // cache_fraction * cache_limit_ bytes are cached. If that fails to free + // enough, attempts to uncaching recently visited states as well. If still + // unable to free enough memory, then widens cache_limit_. + void GC(const State *current, bool free_recent, float cache_fraction = 0.666); + + // Returns the current cache size in bytes or 0 if GC is disabled. + size_t CacheSize() const { return cache_size_; } + + // Returns the cache limit in bytes. + size_t CacheLimit() const { return cache_limit_; } + + private: + static constexpr size_t kMinCacheLimit = 8096; // Minimum cache limit. + + CacheStore store_; // Underlying store. + bool cache_gc_request_; // GC requested but possibly not yet enabled. + size_t cache_limit_; // Number of bytes allowed before GC. + bool cache_gc_; // GC enabled + size_t cache_size_; // Number of bytes cached. +}; + +template +void GCCacheStore::GC(const State *current, bool free_recent, + float cache_fraction) { + if (!cache_gc_) return; + VLOG(2) << "GCCacheStore: Enter GC: object = " + << "(" << this << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; + size_t cache_target = cache_fraction * cache_limit_; + store_.Reset(); + while (!store_.Done()) { + auto *state = store_.GetMutableState(store_.Value()); + if (cache_size_ > cache_target && state->RefCount() == 0 && + (free_recent || !(state->Flags() & kCacheRecent)) && state != current) { + if (state->Flags() & kCacheInit) { + size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc); + if (size < cache_size_) { + cache_size_ -= size; + } + } + store_.Delete(); + } else { + state->SetFlags(0, kCacheRecent); + store_.Next(); + } + } + if (!free_recent && cache_size_ > cache_target) { // Recurses on recent. + GC(current, true, cache_fraction); + } else if (cache_target > 0) { // Widens cache limit. + while (cache_size_ > cache_target) { + cache_limit_ *= 2; + cache_target *= 2; + } + } else if (cache_size_ > 0) { + FSTERROR() << "GCCacheStore:GC: Unable to free all cached states"; + } + VLOG(2) << "GCCacheStore: Exit GC: object = " + << "(" << this << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; +} + +template +constexpr size_t GCCacheStore::kMinCacheLimit; + +// This class is the default cache state and store used by CacheBaseImpl. +// It uses VectorCacheStore for storage decorated by FirstCacheStore +// and GCCacheStore to do (optional) garbage collection. +template +class DefaultCacheStore + : public GCCacheStore>>> { + public: + explicit DefaultCacheStore(const CacheOptions &opts) + : GCCacheStore>>>(opts) { + } +}; + +namespace internal { + +// This class is used to cache FST elements stored in states of type State +// (see CacheState) with the flags used to indicate what has been cached. Use +// HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(), +// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you +// must set the final weight even if the state is non-final to mark it as +// cached. The state storage method and any garbage collection policy are +// determined by the cache store. If the store is passed in with the options, +// CacheBaseImpl takes ownership. +template > +class CacheBaseImpl : public FstImpl { + public: + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = CacheStore; + + using FstImpl::Type; + using FstImpl::Properties; + + explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions()) + : has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(opts.gc), + cache_limit_(opts.gc_limit), + cache_store_(new CacheStore(opts)), + new_cache_store_(true), + own_cache_store_(true) {} + + explicit CacheBaseImpl(const CacheImplOptions &opts) + : has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(opts.gc), + cache_limit_(opts.gc_limit), + cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions( + opts.gc, opts.gc_limit))), + new_cache_store_(!opts.store), + own_cache_store_(opts.store ? opts.own_store : true) {} + + // Preserve gc parameters. If preserve_cache is true, also preserves + // cache data. + CacheBaseImpl(const CacheBaseImpl &impl, + bool preserve_cache = false) + : FstImpl(), + has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(impl.cache_gc_), + cache_limit_(impl.cache_limit_), + cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))), + new_cache_store_(impl.new_cache_store_ || !preserve_cache), + own_cache_store_(true) { + if (preserve_cache) { + *cache_store_ = *impl.cache_store_; + has_start_ = impl.has_start_; + cache_start_ = impl.cache_start_; + nknown_states_ = impl.nknown_states_; + expanded_states_ = impl.expanded_states_; + min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; + max_expanded_state_id_ = impl.max_expanded_state_id_; + } + } + + ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; } + + void SetStart(StateId s) { + cache_start_ = s; + has_start_ = true; + if (s >= nknown_states_) nknown_states_ = s + 1; + } + + void SetFinal(StateId s, Weight weight) { + auto *state = cache_store_->GetMutableState(s); + state->SetFinal(std::move(weight)); + static constexpr auto flags = kCacheFinal | kCacheRecent; + state->SetFlags(flags, flags); + } + +// Disabled to ensure PushArc not AddArc is used in existing code +// TODO(sorenj): re-enable for backing store +#if 0 + // AddArc adds a single arc to a state and does incremental cache + // book-keeping. For efficiency, prefer PushArc and SetArcs below + // when possible. + void AddArc(StateId s, const Arc &arc) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->AddArc(state, arc); + if (arc.nextstate >= nknown_states_) + nknown_states_ = arc.nextstate + 1; + SetExpandedState(s); + static constexpr auto flags = kCacheArcs | kCacheRecent; + state->SetFlags(flags, flags); + } +#endif + + // Adds a single arc to a state but delays cache book-keeping. SetArcs must + // be called when all PushArc and EmplaceArc calls at a state are complete. + // Do not mix with calls to AddArc. + void PushArc(StateId s, const Arc &arc) { + auto *state = cache_store_->GetMutableState(s); + state->PushArc(arc); + } + + void PushArc(StateId s, Arc &&arc) { + auto *state = cache_store_->GetMutableState(s); + state->PushArc(std::move(arc)); + } + + // Adds a single arc to a state but delays cache book-keeping. SetArcs must + // be called when all PushArc and EmplaceArc calls at a state are complete. + // Do not mix with calls to AddArc. + template + void EmplaceArc(StateId s, T &&... ctor_args) { + auto *state = cache_store_->GetMutableState(s); + state->EmplaceArc(std::forward(ctor_args)...); + } + + // Marks arcs of a state as cached and does cache book-keeping after all + // calls to PushArc have been completed. Do not mix with calls to AddArc. + void SetArcs(StateId s) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->SetArcs(state); + const auto narcs = state->NumArcs(); + for (size_t a = 0; a < narcs; ++a) { + const auto &arc = state->GetArc(a); + if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1; + } + SetExpandedState(s); + static constexpr auto flags = kCacheArcs | kCacheRecent; + state->SetFlags(flags, flags); + } + + void ReserveArcs(StateId s, size_t n) { + auto *state = cache_store_->GetMutableState(s); + state->ReserveArcs(n); + } + + void DeleteArcs(StateId s) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->DeleteArcs(state); + } + + void DeleteArcs(StateId s, size_t n) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->DeleteArcs(state, n); + } + + void Clear() { + nknown_states_ = 0; + min_unexpanded_state_id_ = 0; + max_expanded_state_id_ = -1; + has_start_ = false; + cache_start_ = kNoStateId; + cache_store_->Clear(); + } + + // Is the start state cached? + bool HasStart() const { + if (!has_start_ && Properties(kError)) has_start_ = true; + return has_start_; + } + + // Is the final weight of the state cached? + bool HasFinal(StateId s) const { + const auto *state = cache_store_->GetState(s); + if (state && state->Flags() & kCacheFinal) { + state->SetFlags(kCacheRecent, kCacheRecent); + return true; + } else { + return false; + } + } + + // Are arcs of the state cached? + bool HasArcs(StateId s) const { + const auto *state = cache_store_->GetState(s); + if (state && state->Flags() & kCacheArcs) { + state->SetFlags(kCacheRecent, kCacheRecent); + return true; + } else { + return false; + } + } + + StateId Start() const { return cache_start_; } + + Weight Final(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->Final(); + } + + size_t NumArcs(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumArcs(); + } + + size_t NumInputEpsilons(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumInputEpsilons(); + } + + size_t NumOutputEpsilons(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumOutputEpsilons(); + } + + // Provides information needed for generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + const auto *state = cache_store_->GetState(s); + data->base = nullptr; + data->narcs = state->NumArcs(); + data->arcs = state->Arcs(); + data->ref_count = state->MutableRefCount(); + state->IncrRefCount(); + } + + // Number of known states. + StateId NumKnownStates() const { return nknown_states_; } + + // Updates number of known states, taking into account the passed state ID. + void UpdateNumKnownStates(StateId s) { + if (s >= nknown_states_) nknown_states_ = s + 1; + } + + // Finds the mininum never-expanded state ID. + StateId MinUnexpandedState() const { + while (min_unexpanded_state_id_ <= max_expanded_state_id_ && + ExpandedState(min_unexpanded_state_id_)) { + ++min_unexpanded_state_id_; + } + return min_unexpanded_state_id_; + } + + // Returns maximum ever-expanded state ID. + StateId MaxExpandedState() const { return max_expanded_state_id_; } + + void SetExpandedState(StateId s) { + if (s > max_expanded_state_id_) max_expanded_state_id_ = s; + if (s < min_unexpanded_state_id_) return; + if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_; + if (cache_gc_ || cache_limit_ == 0) { + if (expanded_states_.size() <= static_cast(s)) + expanded_states_.resize(s + 1, false); + expanded_states_[s] = true; + } + } + + bool ExpandedState(StateId s) const { + if (cache_gc_ || cache_limit_ == 0) { + return expanded_states_[s]; + } else if (new_cache_store_) { + return cache_store_->GetState(s) != nullptr; + } else { + // If the cache was not created by this class, then the cached state needs + // to be inspected to update nknown_states_. + return false; + } + } + + const CacheStore *GetCacheStore() const { return cache_store_; } + + CacheStore *GetCacheStore() { return cache_store_; } + + // Caching on/off switch, limit and size accessors. + + bool GetCacheGc() const { return cache_gc_; } + + size_t GetCacheLimit() const { return cache_limit_; } + + private: + mutable bool has_start_; // Is the start state cached? + StateId cache_start_; // ID of start state. + StateId nknown_states_; // Number of known states. + std::vector expanded_states_; // States that have been expanded. + mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID + mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID + bool cache_gc_; // GC enabled. + size_t cache_limit_; // Number of bytes allowed before GC. + CacheStore *cache_store_; // The store of cached states. + bool new_cache_store_; // Was the store was created by class? + bool own_cache_store_; // Is the store owned by class? + + CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete; +}; + +// A CacheBaseImpl with the default cache state type. +template +class CacheImpl : public CacheBaseImpl> { + public: + using State = CacheState; + + CacheImpl() {} + + explicit CacheImpl(const CacheOptions &opts) + : CacheBaseImpl>(opts) {} + + CacheImpl(const CacheImpl &impl, bool preserve_cache = false) + : CacheBaseImpl(impl, preserve_cache) {} + + private: + CacheImpl &operator=(const CacheImpl &impl) = delete; +}; + +} // namespace internal + +// Use this to make a state iterator for a CacheBaseImpl-derived FST, which must +// have Arc and Store types defined. Note this iterator only returns those +// states reachable from the initial state, so consider implementing a +// class-specific one. +// +// This class may be derived from. +template +class CacheStateIterator : public StateIteratorBase { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + CacheStateIterator(const FST &fst, Impl *impl) + : fst_(fst), impl_(impl), s_(0) { + fst_.Start(); // Forces start state. + } + + bool Done() const final { + if (s_ < impl_->NumKnownStates()) return false; + for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates(); + u = impl_->MinUnexpandedState()) { + // Forces state expansion. + ArcIterator aiter(fst_, u); + aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache); + for (; !aiter.Done(); aiter.Next()) { + impl_->UpdateNumKnownStates(aiter.Value().nextstate); + } + impl_->SetExpandedState(u); + if (s_ < impl_->NumKnownStates()) return false; + } + return true; + } + + StateId Value() const final { return s_; } + + void Next() final { ++s_; } + + void Reset() final { s_ = 0; } + + private: + const FST &fst_; + Impl *impl_; + StateId s_; +}; + +// Used to make an arc iterator for a CacheBaseImpl-derived FST, which must +// have Arc and State types defined. +template +class CacheArcIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + CacheArcIterator(Impl *impl, StateId s) : i_(0) { + state_ = impl->GetCacheStore()->GetMutableState(s); + state_->IncrRefCount(); + } + + ~CacheArcIterator() { state_->DecrRefCount(); } + + bool Done() const { return i_ >= state_->NumArcs(); } + + const Arc &Value() const { return state_->GetArc(i_); } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32 flags, uint32 mask) {} + + private: + const State *state_; + size_t i_; + + CacheArcIterator(const CacheArcIterator &) = delete; + CacheArcIterator &operator=(const CacheArcIterator &) = delete; +}; + +// Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST, +// which must have types Arc and Store defined. +template +class CacheMutableArcIterator + : public MutableArcIteratorBase { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + // User must call MutateCheck() in the constructor. + CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) { + state_ = impl_->GetCacheStore()->GetMutableState(s_); + state_->IncrRefCount(); + } + + ~CacheMutableArcIterator() override { state_->DecrRefCount(); } + + bool Done() const final { return i_ >= state_->NumArcs(); } + + const Arc &Value() const final { return state_->GetArc(i_); } + + void Next() final { ++i_; } + + size_t Position() const final { return i_; } + + void Reset() final { i_ = 0; } + + void Seek(size_t a) final { i_ = a; } + + void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + size_t i_; + StateId s_; + Impl *impl_; + State *state_; + + CacheMutableArcIterator(const CacheMutableArcIterator &) = delete; + CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete; +}; + +// Wrap existing CacheStore implementation to use with ExpanderFst. +template +class ExpanderCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename CacheStore::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions()) + : store_(opts) {} + + template + State *FindOrExpand(Expander &expander, StateId s) { // NOLINT + auto *state = store_.GetMutableState(s); + if (state->Flags()) { + state->SetFlags(kCacheRecent, kCacheRecent); + } else { + StateBuilder builder(state); + expander.Expand(s, &builder); + state->SetFlags(kCacheFlags, kCacheFlags); + store_.SetArcs(state); + } + return state; + } + + private: + CacheStore store_; + + struct StateBuilder { + State *state; + + explicit StateBuilder(State *state_) : state(state_) {} + + void AddArc(const Arc &arc) { state->PushArc(arc); } + + void AddArc(Arc &&arc) { state->PushArc(std::move(arc)); } + + void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); } + }; +}; + +} // namespace fst + +#endif // FST_CACHE_H_ diff --git a/projects/llm_framework/include/fst/closure.h b/projects/llm_framework/include/fst/closure.h new file mode 100644 index 00000000..13beea9c --- /dev/null +++ b/projects/llm_framework/include/fst/closure.h @@ -0,0 +1,134 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the concatenative closure of an FST. + +#ifndef FST_CLOSURE_H_ +#define FST_CLOSURE_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Computes the concatenative closure. This version modifies its +// MutableFst input. If an FST transduces string x to y with weight a, +// then its closure transduces x to y with weight a, xx to yy with +// weight Times(a, a), xxx to yyy with with Times(Times(a, a), a), +// etc. If closure_type == CLOSURE_STAR, then the empty string is +// transduced to itself with weight Weight::One() as well. +// +// Complexity: +// +// Time: O(V) +// Space: O(V) +// +// where V is the number of states. +template +void Closure(MutableFst *fst, ClosureType closure_type) { + using Weight = typename Arc::Weight; + const auto props = fst->Properties(kFstProperties, false); + const auto start = fst->Start(); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + const auto s = siter.Value(); + const auto weight = fst->Final(s); + if (weight != Weight::Zero()) fst->AddArc(s, Arc(0, 0, weight, start)); + } + if (closure_type == CLOSURE_STAR) { + fst->ReserveStates(fst->NumStates() + 1); + const auto nstart = fst->AddState(); + fst->SetStart(nstart); + fst->SetFinal(nstart, Weight::One()); + if (start != kNoLabel) fst->AddArc(nstart, Arc(0, 0, Weight::One(), start)); + } + fst->SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR), + kFstProperties); +} + +// Computes the concatenative closure. This version modifies its +// RationalFst input. +template +void Closure(RationalFst *fst, ClosureType closure_type) { + fst->GetMutableImpl()->AddClosure(closure_type); +} + +struct ClosureFstOptions : RationalFstOptions { + ClosureType type; + + ClosureFstOptions(const RationalFstOptions &opts, + ClosureType type = CLOSURE_STAR) + : RationalFstOptions(opts), type(type) {} + + explicit ClosureFstOptions(ClosureType type = CLOSURE_STAR) : type(type) {} +}; + +// Computes the concatenative closure. This version is a delayed FST. If an FST +// transduces string x to y with weight a, then its closure transduces x to y +// with weight a, xx to yy with weight Times(a, a), xxx to yyy with weight +// Times(Times(a, a), a), etc. If closure_type == CLOSURE_STAR, then the empty +// string is transduced to itself with weight Weight::One() as well. +// +// Complexity: +// +// Time: O(v) +// Space: O(v) +// +// where v is the number of states visited. Constant time and space to visit an +// input state or arc is assumed and exclusive of caching. +template +class ClosureFst : public RationalFst { + public: + using Arc = A; + + ClosureFst(const Fst &fst, ClosureType closure_type) { + GetMutableImpl()->InitClosure(fst, closure_type); + } + + ClosureFst(const Fst &fst, const ClosureFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitClosure(fst, opts.type); + } + + // See Fst<>::Copy() for doc. + ClosureFst(const ClosureFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Gets a copy of this ClosureFst. See Fst<>::Copy() for further doc. + ClosureFst *Copy(bool safe = false) const override { + return new ClosureFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for ClosureFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const ClosureFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for ClosureFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ClosureFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdClosureFst = ClosureFst; + +} // namespace fst + +#endif // FST_CLOSURE_H_ diff --git a/projects/llm_framework/include/fst/compact-fst.h b/projects/llm_framework/include/fst/compact-fst.h new file mode 100644 index 00000000..402c87b7 --- /dev/null +++ b/projects/llm_framework/include/fst/compact-fst.h @@ -0,0 +1,1564 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST Class for memory-efficient representation of common types of +// FSTs: linear automata, acceptors, unweighted FSTs, ... + +#ifndef FST_COMPACT_FST_H_ +#define FST_COMPACT_FST_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include // For optional argument declarations +#include +#include +#include +#include + + +namespace fst { + +struct CompactFstOptions : public CacheOptions { + // The default caching behaviour is to do no caching. Most compactors are + // cheap and therefore we save memory by not doing caching. + CompactFstOptions() : CacheOptions(true, 0) {} + + explicit CompactFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} +}; + +// New upcoming (Fst) Compactor interface - currently used internally +// by CompactFstImpl. +// +// class Compactor { +// public: +// // Constructor from the Fst to be compacted. +// Compactor(const Fst &fst, ...); +// // Copy constructor +// Compactor(const Compactor &compactor, bool safe = false) +// // Default constructor (optional, see comment below). +// Compactor(); +// +// // Returns the start state, number of states, and total number of arcs +// // of the compacted Fst +// StateId Start() const; +// StateId NumStates() const; +// size_t NumArcs() const; +// +// // Accessor class for state attributes. +// class State { +// public: +// State(); // Required, corresponds to kNoStateId. +// State(const Compactor *c, StateId); // Accessor for StateId 's'. +// StateId GetStateId() const; +// Weight Final() const; +// size_t NumArcs() const; +// Arc GetArc(size_t i, uint32 f) const; +// }; +// +// // Modifies 'state' accessor to provide access to state id 's'. +// void SetState(StateId s, State *state); +// // Tests whether 'fst' can be compacted by this compactor. +// bool IsCompatible(const Fst &fst) const; +// // Return the properties that are always true for an fst +// // compacted using this compactor +// uint64 Properties() const; +// // Return a string identifying the type of compactor. +// static const string &Type(); +// // Return true if an error has occured. +// bool Error() const; +// // Writes a compactor to a file. +// bool Write(std::ostream &strm, const FstWriteOptions &opts) const; +// // Reads a compactor from a file. +// static Compactor*Read(std::istream &strm, const FstReadOptions &opts, +// const FstHeader &hdr); +// }; +// + +// Old (Arc) Compactor Interface: +// +// The ArcCompactor class determines how arcs and final weights are compacted +// and expanded. +// +// Final weights are treated as transitions to the superfinal state, i.e., +// ilabel = olabel = kNoLabel and nextstate = kNoStateId. +// +// There are two types of compactors: +// +// * Fixed out-degree compactors: 'compactor.Size()' returns a positive integer +// 's'. An FST can be compacted by this compactor only if each state has +// exactly 's' outgoing transitions (counting a non-Zero() final weight as a +// transition). A typical example is a compactor for string FSTs, i.e., +// 's == 1'. +// +// * Variable out-degree compactors: 'compactor.Size() == -1'. There are no +// out-degree restrictions for these compactors. +// +// Interface: +// +// class ArcCompactor { +// public: +// // Element is the type of the compacted transitions. +// using Element = ... +// +// // Returns the compacted representation of a transition 'arc' +// // at a state 's'. +// Element Compact(StateId s, const Arc &arc); +// +// // Returns the transition at state 's' represented by the compacted +// // transition 'e'. +// Arc Expand(StateId s, const Element &e) const; +// +// // Returns -1 for variable out-degree compactors, and the mandatory +// // out-degree otherwise. +// ssize_t Size() const; +// +// // Tests whether an FST can be compacted by this compactor. +// bool Compatible(const Fst &fst) const; +// +// // Returns the properties that are always true for an FST compacted using +// // this compactor +// uint64 Properties() const; +// +// // Returns a string identifying the type of compactor. +// static const string &Type(); +// +// // Writes a compactor to a file. +// bool Write(std::ostream &strm) const; +// +// // Reads a compactor from a file. +// static ArcCompactor *Read(std::istream &strm); +// +// // Default constructor (optional, see comment below). +// ArcCompactor(); +// }; +// +// The default constructor is only required for FST_REGISTER to work (i.e., +// enabling Convert() and the command-line utilities to work with this new +// compactor). However, a default constructor always needs to be specified for +// this code to compile, but one can have it simply raise an error when called, +// like so: +// +// Compactor::Compactor() { +// FSTERROR() << "Compactor: No default constructor"; +// } + +// Default implementation data for CompactFst, which can shared between +// otherwise independent copies. +// +// The implementation contains two arrays: 'states_' and 'compacts_'. +// +// For fixed out-degree compactors, the 'states_' array is unallocated. The +// 'compacts_' contains the compacted transitions. Its size is 'ncompacts_'. +// The outgoing transitions at a given state are stored consecutively. For a +// given state 's', its 'compactor.Size()' outgoing transitions (including +// superfinal transition when 's' is final), are stored in position +// ['s*compactor.Size()', '(s+1)*compactor.Size()'). +// +// For variable out-degree compactors, the states_ array has size +// 'nstates_ + 1' and contains pointers to positions into 'compacts_'. For a +// given state 's', the compacted transitions of 's' are stored in positions +// ['states_[s]', 'states_[s + 1]') in 'compacts_'. By convention, +// 'states_[nstates_] == ncompacts_'. +// +// In both cases, the superfinal transitions (when 's' is final, i.e., +// 'Final(s) != Weight::Zero()') are stored first. +// +// The unsigned type U is used to represent indices into the compacts_ array. +template +class DefaultCompactStore { + public: + DefaultCompactStore() + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) {} + + template + DefaultCompactStore(const Fst &fst, const Compactor &compactor); + + template + DefaultCompactStore(const Iterator &begin, const Iterator &end, + const Compactor &compactor); + + ~DefaultCompactStore() { + if (!states_region_) delete[] states_; + if (!compacts_region_) delete[] compacts_; + } + + template + static DefaultCompactStore *Read( + std::istream &strm, const FstReadOptions &opts, const FstHeader &hdr, + const Compactor &compactor); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const; + + Unsigned States(ssize_t i) const { return states_[i]; } + + const Element &Compacts(size_t i) const { return compacts_[i]; } + + size_t NumStates() const { return nstates_; } + + size_t NumCompacts() const { return ncompacts_; } + + size_t NumArcs() const { return narcs_; } + + ssize_t Start() const { return start_; } + + bool Error() const { return error_; } + + // Returns a string identifying the type of data storage container. + static const string &Type(); + + private: + std::unique_ptr states_region_; + std::unique_ptr compacts_region_; + Unsigned *states_; + Element *compacts_; + size_t nstates_; + size_t ncompacts_; + size_t narcs_; + ssize_t start_; + bool error_; +}; + +template +template +DefaultCompactStore::DefaultCompactStore( + const Fst &fst, const Compactor &compactor) + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + start_ = fst.Start(); + // Counts # of states and arcs. + StateId nfinals = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + const auto s = siter.Value(); + narcs_ += fst.NumArcs(s); + if (fst.Final(s) != Weight::Zero()) ++nfinals; + } + if (compactor.Size() == -1) { + states_ = new Unsigned[nstates_ + 1]; + ncompacts_ = narcs_ + nfinals; + compacts_ = new Element[ncompacts_]; + states_[nstates_] = ncompacts_; + } else { + states_ = nullptr; + ncompacts_ = nstates_ * compactor.Size(); + if ((narcs_ + nfinals) != ncompacts_) { + FSTERROR() << "DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } + compacts_ = new Element[ncompacts_]; + } + size_t pos = 0; + size_t fpos = 0; + for (size_t s = 0; s < nstates_; ++s) { + fpos = pos; + if (compactor.Size() == -1) states_[s] = pos; + if (fst.Final(s) != Weight::Zero()) { + compacts_[pos++] = compactor.Compact( + s, Arc(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + compacts_[pos++] = compactor.Compact(s, aiter.Value()); + } + if ((compactor.Size() != -1) && (pos != fpos + compactor.Size())) { + FSTERROR() << "DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } + } + if (pos != ncompacts_) { + FSTERROR() << "DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } +} + +template +template +DefaultCompactStore::DefaultCompactStore( + const Iterator &begin, const Iterator &end, const Compactor &compactor) + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + using Arc = typename Compactor::Arc; + using Weight = typename Arc::Weight; + if (compactor.Size() != -1) { + ncompacts_ = std::distance(begin, end); + if (compactor.Size() == 1) { + // For strings, allows implicit final weight. Empty input is the empty + // string. + if (ncompacts_ == 0) { + ++ncompacts_; + } else { + const auto arc = + compactor.Expand(ncompacts_ - 1, *(begin + (ncompacts_ - 1))); + if (arc.ilabel != kNoLabel) ++ncompacts_; + } + } + if (ncompacts_ % compactor.Size()) { + FSTERROR() << "DefaultCompactStore: Size of input container incompatible" + << " with compactor"; + error_ = true; + return; + } + if (ncompacts_ == 0) return; + start_ = 0; + nstates_ = ncompacts_ / compactor.Size(); + compacts_ = new Element[ncompacts_]; + size_t i = 0; + Iterator it = begin; + for (; it != end; ++it, ++i) { + compacts_[i] = *it; + if (compactor.Expand(i, *it).ilabel != kNoLabel) ++narcs_; + } + if (i < ncompacts_) { + compacts_[i] = compactor.Compact( + i, Arc(kNoLabel, kNoLabel, Weight::One(), kNoStateId)); + } + } else { + if (std::distance(begin, end) == 0) return; + // Count # of states, arcs and compacts. + auto it = begin; + for (size_t i = 0; it != end; ++it, ++i) { + const auto arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + ++narcs_; + ++ncompacts_; + } else { + ++nstates_; + if (arc.weight != Weight::Zero()) ++ncompacts_; + } + } + start_ = 0; + compacts_ = new Element[ncompacts_]; + states_ = new Unsigned[nstates_ + 1]; + states_[nstates_] = ncompacts_; + size_t i = 0; + size_t s = 0; + for (it = begin; it != end; ++it) { + const auto arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + compacts_[i++] = *it; + } else { + states_[s++] = i; + if (arc.weight != Weight::Zero()) compacts_[i++] = *it; + } + } + if ((s != nstates_) || (i != ncompacts_)) { + FSTERROR() << "DefaultCompactStore: Ill-formed input container"; + error_ = true; + return; + } + } +} + +template +template +DefaultCompactStore + *DefaultCompactStore::Read(std::istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr, + const Compactor &compactor) { + std::unique_ptr> data( + new DefaultCompactStore()); + data->start_ = hdr.Start(); + data->nstates_ = hdr.NumStates(); + data->narcs_ = hdr.NumArcs(); + if (compactor.Size() == -1) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Read: Alignment failed: " + << opts.source; + return nullptr; + } + auto b = (data->nstates_ + 1) * sizeof(Unsigned); + data->states_region_.reset(MappedFile::Map( + &strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !data->states_region_) { + LOG(ERROR) << "DefaultCompactStore::Read: Read failed: " << opts.source; + return nullptr; + } + data->states_ = + static_cast(data->states_region_->mutable_data()); + } else { + data->states_ = nullptr; + } + data->ncompacts_ = compactor.Size() == -1 ? data->states_[data->nstates_] + : data->nstates_ * compactor.Size(); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Read: Alignment failed: " + << opts.source; + return nullptr; + } + size_t b = data->ncompacts_ * sizeof(Element); + data->compacts_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !data->compacts_region_) { + LOG(ERROR) << "DefaultCompactStore::Read: Read failed: " << opts.source; + return nullptr; + } + data->compacts_ = + static_cast(data->compacts_region_->mutable_data()); + return data.release(); +} + +template +bool DefaultCompactStore::Write( + std::ostream &strm, const FstWriteOptions &opts) const { + if (states_) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Write: Alignment failed: " + << opts.source; + return false; + } + strm.write(reinterpret_cast(states_), + (nstates_ + 1) * sizeof(Unsigned)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Write: Alignment failed: " + << opts.source; + return false; + } + strm.write(reinterpret_cast(compacts_), ncompacts_ * sizeof(Element)); + strm.flush(); + if (!strm) { + LOG(ERROR) << "DefaultCompactStore::Write: Write failed: " << opts.source; + return false; + } + return true; +} + +template +const string &DefaultCompactStore::Type() { + static const string *const type = new string("compact"); + return *type; +} + +template class DefaultCompactState; + +// Wraps an arc compactor and a compact store as a new Fst compactor. +template > +class DefaultCompactor { + public: + using ArcCompactor = C; + using Unsigned = U; + using CompactStore = S; + using Element = typename C::Element; + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using State = DefaultCompactState; + friend State; + + DefaultCompactor() + : arc_compactor_(nullptr), compact_store_(nullptr) {} + + // Constructs from Fst. + DefaultCompactor(const Fst &fst, + std::shared_ptr arc_compactor) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::make_shared(fst, *arc_compactor_)) {} + + DefaultCompactor(const Fst &fst, + std::shared_ptr> compactor) + : arc_compactor_(compactor->arc_compactor_), + compact_store_(compactor->compact_store_ == nullptr ? + std::make_shared(fst, *arc_compactor_) : + compactor->compact_store_) {} + + // Constructs from CompactStore. + DefaultCompactor(std::shared_ptr arc_compactor, + std::shared_ptr compact_store) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::move(compact_store)) {} + + // Constructs from set of compact elements (when arc_compactor.Size() != -1). + template + DefaultCompactor(const Iterator &b, const Iterator &e, + std::shared_ptr arc_compactor) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::make_shared(b, e, *arc_compactor_)) {} + + // Copy constructor. + DefaultCompactor(const DefaultCompactor &compactor) + : arc_compactor_(std::make_shared(*compactor.GetArcCompactor())), + compact_store_(compactor.SharedCompactStore()) {} + + template + explicit DefaultCompactor(const DefaultCompactor &compactor) + : arc_compactor_(std::make_shared(*compactor.GetArcCompactor())), + compact_store_(compactor.SharedCompactStore()) {} + + StateId Start() const { return compact_store_->Start(); } + StateId NumStates() const { return compact_store_->NumStates(); } + size_t NumArcs() const { return compact_store_->NumArcs(); } + + void SetState(StateId s, State *state) const { + if (state->GetStateId() != s) state->Set(this, s); + } + + static DefaultCompactor *Read(std::istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr) { + std::shared_ptr arc_compactor(C::Read(strm)); + if (arc_compactor == nullptr) return nullptr; + std::shared_ptr compact_store(S::Read(strm, opts, hdr, *arc_compactor)); + if (compact_store == nullptr) return nullptr; + return new DefaultCompactor(arc_compactor, compact_store); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + return arc_compactor_->Write(strm) && compact_store_->Write(strm, opts); + } + + uint64 Properties() const { return arc_compactor_->Properties(); } + + bool IsCompatible(const Fst &fst) const { + return arc_compactor_->Compatible(fst); + } + + bool Error() const { return compact_store_->Error(); } + + bool HasFixedOutdegree() const { return arc_compactor_->Size() != -1; } + + static const string &Type() { + static const string *const type = [] { + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) type += std::to_string(8 * sizeof(U)); + type += "_"; + type += C::Type(); + if (CompactStore::Type() != "compact") { + type += "_"; + type += CompactStore::Type(); + } + return new string(type); + }(); + return *type; + } + + const ArcCompactor *GetArcCompactor() const { return arc_compactor_.get(); } + CompactStore *GetCompactStore() const { return compact_store_.get(); } + + std::shared_ptr SharedArcCompactor() const { + return arc_compactor_; + } + + std::shared_ptr SharedCompactStore() const { + return compact_store_; + } + + // TODO(allauzen): remove dependencies on this method and make private. + Arc ComputeArc(StateId s, Unsigned i, uint32 f) const { + return arc_compactor_->Expand(s, compact_store_->Compacts(i), f); + } + + private: + std::pair CompactsRange(StateId s) const { + std::pair range; + if (HasFixedOutdegree()) { + range.first = s * arc_compactor_->Size(); + range.second = arc_compactor_->Size(); + } else { + range.first = compact_store_->States(s); + range.second = compact_store_->States(s + 1) - range.first; + } + return range; + } + + private: + std::shared_ptr arc_compactor_; + std::shared_ptr compact_store_; +}; + +// Default implementation of state attributes accessor class for +// DefaultCompactor. Use of efficient specialization strongly encouraged. +template +class DefaultCompactState { + public: + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DefaultCompactState() = default; + + DefaultCompactState(const DefaultCompactor *compactor, StateId s) + : compactor_(compactor), + s_(s), + range_(compactor->CompactsRange(s)), + has_final_( + range_.second != 0 && + compactor->ComputeArc(s, range_.first, + kArcILabelValue).ilabel == kNoLabel) { + if (has_final_) { + ++range_.first; + --range_.second; + } + } + + void Set(const DefaultCompactor *compactor, StateId s) { + compactor_ = compactor; + s_ = s; + range_ = compactor->CompactsRange(s); + if (range_.second != 0 && + compactor->ComputeArc(s, range_.first, kArcILabelValue).ilabel + == kNoLabel) { + has_final_ = true; + ++range_.first; + --range_.second; + } else { + has_final_ = false; + } + } + + StateId GetStateId() const { return s_; } + + Weight Final() const { + if (!has_final_) return Weight::Zero(); + return compactor_->ComputeArc(s_, range_.first - 1, kArcWeightValue).weight; + } + + size_t NumArcs() const { return range_.second; } + + Arc GetArc(size_t i, uint32 f) const { + return compactor_->ComputeArc(s_, range_.first + i, f); + } + + private: + const DefaultCompactor *compactor_ = nullptr; // borrowed ref. + StateId s_ = kNoStateId; + std::pair range_ = {0, 0}; + bool has_final_ = false; +}; + +// Specialization for DefaultCompactStore. +template +class DefaultCompactState> { + public: + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using CompactStore = DefaultCompactStore; + + DefaultCompactState() = default; + + DefaultCompactState( + const DefaultCompactor *compactor, StateId s) + : arc_compactor_(compactor->GetArcCompactor()), s_(s) { + Init(compactor); + } + + void Set(const DefaultCompactor *compactor, StateId s) { + arc_compactor_ = compactor->GetArcCompactor(); + s_ = s; + has_final_ = false; + Init(compactor); + } + + StateId GetStateId() const { return s_; } + + Weight Final() const { + if (!has_final_) return Weight::Zero(); + return arc_compactor_->Expand(s_, *(compacts_ - 1), kArcWeightValue).weight; + } + + size_t NumArcs() const { return num_arcs_; } + + Arc GetArc(size_t i, uint32 f) const { + return arc_compactor_->Expand(s_, compacts_[i], f); + } + + private: + void Init(const DefaultCompactor *compactor) { + const auto *store = compactor->GetCompactStore(); + U offset; + if (!compactor->HasFixedOutdegree()) { // Variable out-degree compactor. + offset = store->States(s_); + num_arcs_ = store->States(s_ + 1) - offset; + } else { // Fixed out-degree compactor. + offset = s_ * arc_compactor_->Size(); + num_arcs_ = arc_compactor_->Size(); + } + if (num_arcs_ > 0) { + compacts_ = &(store->Compacts(offset)); + if (arc_compactor_->Expand(s_, *compacts_, kArcILabelValue).ilabel + == kNoStateId) { + ++compacts_; + --num_arcs_; + has_final_ = true; + } + } + } + + private: + const C *arc_compactor_ = nullptr; // Borrowed reference. + const typename C::Element *compacts_ = nullptr; // Borrowed reference. + StateId s_ = kNoStateId; + U num_arcs_ = 0; + bool has_final_ = false; +}; + +template +class CompactFst; + +template +void Cast(const F &, G *); + +namespace internal { + +// Implementation class for CompactFst, which contains parametrizeable +// Fst data storage (DefaultCompactStore by default) and Fst cache. +template > +class CompactFstImpl + : public CacheBaseImpl { + public: + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + using Compactor = C; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using ImplBase = CacheBaseImpl; + using ImplBase::PushArc; + using ImplBase::HasArcs; + using ImplBase::HasFinal; + using ImplBase::HasStart; + using ImplBase::SetArcs; + using ImplBase::SetFinal; + using ImplBase::SetStart; + + CompactFstImpl() + : ImplBase(CompactFstOptions()), + compactor_() { + SetType(Compactor::Type()); + SetProperties(kNullProperties | kStaticProperties); + } + + CompactFstImpl(const Fst &fst, std::shared_ptr compactor, + const CompactFstOptions &opts) + : ImplBase(opts), + compactor_(std::make_shared(fst, compactor)) { + SetType(Compactor::Type()); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + if (compactor_->Error()) SetProperties(kError, kError); + uint64 copy_properties = fst.Properties(kMutable, false) ? + fst.Properties(kCopyProperties, true): + CheckProperties(fst, + kCopyProperties & ~kWeightedCycles & ~kUnweightedCycles, + kCopyProperties); + if ((copy_properties & kError) || !compactor_->IsCompatible(fst)) { + FSTERROR() << "CompactFstImpl: Input Fst incompatible with compactor"; + SetProperties(kError, kError); + return; + } + SetProperties(copy_properties | kStaticProperties); + } + + CompactFstImpl(std::shared_ptr compactor, + const CompactFstOptions &opts) + : ImplBase(opts), + compactor_(compactor) { + SetType(Compactor::Type()); + SetProperties(kStaticProperties | compactor_->Properties()); + if (compactor_->Error()) SetProperties(kError, kError); + } + + CompactFstImpl(const CompactFstImpl &impl) + : ImplBase(impl), + compactor_(impl.compactor_ == nullptr ? + std::make_shared() : + std::make_shared(*impl.compactor_)) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + // Allows to change the cache store from OtherI to I. + template + CompactFstImpl(const CompactFstImpl &impl) + : ImplBase(CacheOptions(impl.GetCacheGc(), impl.GetCacheLimit())), + compactor_(impl.compactor_ == nullptr ? + std::make_shared() : + std::make_shared(*impl.compactor_)) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(compactor_->Start()); + return ImplBase::Start(); + } + + Weight Final(StateId s) { + if (HasFinal(s)) return ImplBase::Final(s); + compactor_->SetState(s, &state_); + return state_.Final(); + } + + StateId NumStates() const { + if (Properties(kError)) return 0; + return compactor_->NumStates(); + } + + size_t NumArcs(StateId s) { + if (HasArcs(s)) return ImplBase::NumArcs(s); + compactor_->SetState(s, &state_); + return state_.NumArcs(); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kILabelSorted)) Expand(s); + if (HasArcs(s)) return ImplBase::NumInputEpsilons(s); + return CountEpsilons(s, false); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kOLabelSorted)) Expand(s); + if (HasArcs(s)) return ImplBase::NumOutputEpsilons(s); + return CountEpsilons(s, true); + } + + size_t CountEpsilons(StateId s, bool output_epsilons) { + compactor_->SetState(s, &state_); + const uint32 f = output_epsilons ? kArcOLabelValue : kArcILabelValue; + size_t num_eps = 0; + for (size_t i = 0; i < state_.NumArcs(); ++i) { + const auto& arc = state_.GetArc(i, f); + const auto label = output_epsilons ? arc.olabel : arc.ilabel; + if (label == 0) + ++num_eps; + else if (label > 0) + break; + } + return num_eps; + } + + static CompactFstImpl *Read( + std::istream &strm, const FstReadOptions &opts) { + std::unique_ptr> impl( + new CompactFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + return nullptr; + } + // Ensures compatibility. + if (hdr.Version() == kAlignedFileVersion) { + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + } + impl->compactor_ = std::shared_ptr( + Compactor::Read(strm, opts, hdr)); + if (!impl->compactor_) { + return nullptr; + } + return impl.release(); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(compactor_->Start()); + hdr.SetNumStates(compactor_->NumStates()); + hdr.SetNumArcs(compactor_->NumArcs()); + // Ensures compatibility. + const auto file_version = opts.align ? kAlignedFileVersion : kFileVersion; + WriteHeader(strm, opts, file_version, &hdr); + return compactor_->Write(strm, opts); + } + + // Provides information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = compactor_->NumStates(); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + ImplBase::InitArcIterator(s, data); + } + + void Expand(StateId s) { + compactor_->SetState(s, &state_); + for (size_t i = 0; i < state_.NumArcs(); ++i) + PushArc(s, state_.GetArc(i, kArcValueFlags)); + SetArcs(s); + if (!HasFinal(s)) SetFinal(s, state_.Final()); + } + + const Compactor *GetCompactor() const { return compactor_.get(); } + std::shared_ptr SharedCompactor() const { return compactor_; } + void SetCompactor(std::shared_ptr compactor) { + // TODO(allauzen): is this correct? is this needed? + // TODO(allauzen): consider removing and forcing this through direct calls + // to compactor. + compactor_ = compactor; + } + + // Properties always true of this FST class. + static constexpr uint64 kStaticProperties = kExpanded; + + protected: + template + explicit CompactFstImpl( + const CompactFstImpl &impl) + : compactor_(std::make_shared(*impl.GetCompactor())) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + private: + // Allows access during write. + template + friend class ::fst::CompactFst; // allow access during write. + + // Current unaligned file format version. + static constexpr int kFileVersion = 2; + // Current aligned file format version. + static constexpr int kAlignedFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + std::shared_ptr compactor_; + typename Compactor::State state_; +}; + +template +constexpr uint64 CompactFstImpl::kStaticProperties; + +template +constexpr int CompactFstImpl::kFileVersion; + +template +constexpr int CompactFstImpl::kAlignedFileVersion; + +template +constexpr int CompactFstImpl::kMinFileVersion; + +} // namespace internal + +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToExpandedFst. The Unsigned type +// is used to represent indices into the compact arc array. (Template +// argument defaults are declared in fst-decl.h.) +template +class CompactFst + : public ImplToExpandedFst, + CacheStore>> { + public: + template + void friend Cast(const F &, G *); + + using Arc = A; + using StateId = typename A::StateId; + using Compactor = DefaultCompactor; + using Impl = internal::CompactFstImpl; + using Store = CacheStore; // for CacheArcIterator + + friend class StateIterator< + CompactFst>; + friend class ArcIterator< + CompactFst>; + + CompactFst() : ImplToExpandedFst(std::make_shared()) {} + + // If data is not nullptr, it is assumed to be already initialized. + explicit CompactFst( + const Fst &fst, + const ArcCompactor &compactor = ArcCompactor(), + const CompactFstOptions &opts = CompactFstOptions(), + std::shared_ptr data = std::shared_ptr()) + : ImplToExpandedFst( + std::make_shared( + fst, + std::make_shared( + std::make_shared(compactor), data), + opts)) {} + + // If data is not nullptr, it is assumed to be already initialized. + CompactFst( + const Fst &fst, + std::shared_ptr compactor, + const CompactFstOptions &opts = CompactFstOptions(), + std::shared_ptr data = std::shared_ptr()) + : ImplToExpandedFst( + std::make_shared(fst, + std::make_shared(compactor, data), + opts)) {} + + // The following 2 constructors take as input two iterators delimiting a set + // of (already) compacted transitions, starting with the transitions out of + // the initial state. The format of the input differs for fixed out-degree + // and variable out-degree compactors. + // + // - For fixed out-degree compactors, the final weight (encoded as a + // compacted transition) needs to be given only for final states. All strings + // (compactor of size 1) will be assume to be terminated by a final state + // even when the final state is not implicitely given. + // + // - For variable out-degree compactors, the final weight (encoded as a + // compacted transition) needs to be given for all states and must appeared + // first in the list (for state s, final weight of s, followed by outgoing + // transitons in s). + // + // These 2 constructors allows the direct construction of a CompactFst + // without first creating a more memory-hungry regular FST. This is useful + // when memory usage is severely constrained. + template + explicit CompactFst(const Iterator &begin, const Iterator &end, + const ArcCompactor &compactor = ArcCompactor(), + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst( + std::make_shared( + std::make_shared( + begin, end, std::make_shared(compactor)), + opts)) {} + + template + CompactFst(const Iterator &begin, const Iterator &end, + std::shared_ptr compactor, + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst( + std::make_shared( + std::make_shared(begin, end, compactor), opts)) {} + + // See Fst<>::Copy() for doc. + CompactFst( + const CompactFst + &fst, + bool safe = false) + : ImplToExpandedFst(fst, safe) {} + + // Get a copy of this CompactFst. See Fst<>::Copy() for further doc. + CompactFst *Copy( + bool safe = false) const override { + return new CompactFst( + *this, safe); + } + + // Read a CompactFst from an input stream; return nullptr on error + static CompactFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new CompactFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a CompactFst from a file; return nullptr on error + // Empty filename reads from standard input + static CompactFst *Read( + const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? new CompactFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, const ArcCompactor &compactor, + std::ostream &strm, const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new SortedMatcher< + CompactFst>( + *this, match_type); + } + + template + void SetCompactElements(const Iterator &b, const Iterator &e) { + GetMutableImpl()->SetCompactor(std::make_shared( + b, e, std::make_shared())); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; + + explicit CompactFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + // Use overloading to extract the type of the argument. + static Impl *GetImplIfCompactFst( + const CompactFst + &compact_fst) { + return compact_fst.GetImpl(); + } + + // This does not give privileged treatment to subclasses of CompactFst. + template + static Impl *GetImplIfCompactFst(const NonCompactFst &fst) { + return nullptr; + } + + CompactFst &operator=(const CompactFst &fst) = delete; +}; + +// Writes FST in Compact format, with a possible pass over the machine before +// writing to compute the number of states and arcs. +template +template +bool CompactFst::WriteFst( + const FST &fst, const ArcCompactor &compactor, std::ostream &strm, + const FstWriteOptions &opts) { + using Arc = A; + using Weight = typename A::Weight; + using Element = typename ArcCompactor::Element; + const auto file_version = + opts.align ? Impl::kAlignedFileVersion : Impl::kFileVersion; + size_t num_arcs = -1; + size_t num_states = -1; + auto first_pass_compactor = compactor; + if (auto *impl = GetImplIfCompactFst(fst)) { + num_arcs = impl->GetCompactor()->GetCompactStore()->NumArcs(); + num_states = impl->GetCompactor()->GetCompactStore()->NumStates(); + first_pass_compactor = *impl->GetCompactor()->GetArcCompactor(); + } else { + // A first pass is needed to compute the state of the compactor, which + // is saved ahead of the rest of the data structures. This unfortunately + // means forcing a complete double compaction when writing in this format. + // TODO(allauzen): eliminate mutable state from compactors. + num_arcs = 0; + num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + ++num_states; + if (fst.Final(s) != Weight::Zero()) { + first_pass_compactor.Compact( + s, Arc(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + ++num_arcs; + first_pass_compactor.Compact(s, aiter.Value()); + } + } + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "compact"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + type += "_"; + type += ArcCompactor::Type(); + if (CompactStore::Type() != "compact") { + type += "_"; + type += CompactStore::Type(); + } + const auto copy_properties = fst.Properties(kCopyProperties, true); + if ((copy_properties & kError) || !compactor.Compatible(fst)) { + FSTERROR() << "Fst incompatible with compactor"; + return false; + } + uint64 properties = copy_properties | Impl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, type, + properties, &hdr); + first_pass_compactor.Write(strm); + if (first_pass_compactor.Size() == -1) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + Unsigned compacts = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + strm.write(reinterpret_cast(&compacts), sizeof(compacts)); + if (fst.Final(s) != Weight::Zero()) { + ++compacts; + } + compacts += fst.NumArcs(s); + } + strm.write(reinterpret_cast(&compacts), sizeof(compacts)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + const auto &second_pass_compactor = compactor; + Element element; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) { + element = second_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + strm.write(reinterpret_cast(&element), sizeof(element)); + } + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + element = second_pass_compactor.Compact(s, aiter.Value()); + strm.write(reinterpret_cast(&element), sizeof(element)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "CompactFst write failed: " << opts.source; + return false; + } + return true; +} + +// Specialization for CompactFst; see generic version in fst.h for sample +// usage (but use the CompactFst type!). This version should inline. +template +class StateIterator< + CompactFst> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator( + const CompactFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; +}; + +// Specialization for CompactFst. Never caches, +// always iterates over the underlying compact elements. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + using Element = typename ArcCompactor::Element; + using Compactor = DefaultCompactor; + using State = typename Compactor::State; + + ArcIterator(const CompactFst &fst, + StateId s) + : state_(fst.GetImpl()->GetCompactor(), s), + pos_(0), + flags_(kArcValueFlags) {} + + bool Done() const { return pos_ >= state_.NumArcs(); } + + const Arc &Value() const { + arc_ = state_.GetArc(pos_, flags_); + return arc_; + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + State state_; + size_t pos_; + mutable Arc arc_; + uint32 flags_; +}; + +// ArcCompactor for unweighted string FSTs. +template +class StringCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = Label; + + Element Compact(StateId s, const Arc &arc) const { return arc.ilabel; } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p, p, Weight::One(), p != kNoLabel ? s + 1 : kNoStateId); + } + + constexpr ssize_t Size() const { return 1; } + + constexpr uint64 Properties() const { + return kString | kAcceptor | kUnweighted; + } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("string"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static StringCompactor *Read(std::istream &strm) { + return new StringCompactor; + } +}; + +// ArcCompactor for weighted string FSTs. +template +class WeightedStringCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(arc.ilabel, arc.weight); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, p.second, + p.first != kNoLabel ? s + 1 : kNoStateId); + } + + constexpr ssize_t Size() const { return 1; } + + constexpr uint64 Properties() const { return kString | kAcceptor; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("weighted_string"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static WeightedStringCompactor *Read(std::istream &strm) { + return new WeightedStringCompactor; + } +}; + +// ArcCompactor for unweighted acceptor FSTs. +template +class UnweightedAcceptorCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(arc.ilabel, arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, Weight::One(), p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kAcceptor | kUnweighted; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("unweighted_acceptor"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static UnweightedAcceptorCompactor *Read(std::istream &istrm) { + return new UnweightedAcceptorCompactor; + } +}; + +// ArcCompactor for weighted acceptor FSTs. +template +class AcceptorCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair, StateId>; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(std::make_pair(arc.ilabel, arc.weight), + arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.first, p.first.second, p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kAcceptor; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("acceptor"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static AcceptorCompactor *Read(std::istream &strm) { + return new AcceptorCompactor; + } +}; + +// ArcCompactor for unweighted FSTs. +template +class UnweightedCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair, StateId>; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(std::make_pair(arc.ilabel, arc.olabel), + arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.second, Weight::One(), p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kUnweighted; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("unweighted"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static UnweightedCompactor *Read(std::istream &strm) { + return new UnweightedCompactor; + } +}; + +template +using CompactStringFst = CompactFst, Unsigned>; + +template +using CompactWeightedStringFst = + CompactFst, Unsigned>; + +template +using CompactAcceptorFst = CompactFst, Unsigned>; + +template +using CompactUnweightedFst = + CompactFst, Unsigned>; + +template +using CompactUnweightedAcceptorFst = + CompactFst, Unsigned>; + +using StdCompactStringFst = CompactStringFst; + +using StdCompactWeightedStringFst = CompactWeightedStringFst; + +using StdCompactAcceptorFst = CompactAcceptorFst; + +using StdCompactUnweightedFst = CompactUnweightedFst; + +using StdCompactUnweightedAcceptorFst = + CompactUnweightedAcceptorFst; + +} // namespace fst + +#endif // FST_COMPACT_FST_H_ diff --git a/projects/llm_framework/include/fst/compat.h b/projects/llm_framework/include/fst/compat.h new file mode 100644 index 00000000..73ed5737 --- /dev/null +++ b/projects/llm_framework/include/fst/compat.h @@ -0,0 +1,130 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_LIB_COMPAT_H_ +#define FST_LIB_COMPAT_H_ + +#include +#include +#include +#include +#include +#include + +// Makes copy constructor and operator= private +// Deprecated: now just use =delete. +#define DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator=(const type&) + +#if defined(__GNUC__) || defined(__clang__) +#define OPENFST_DEPRECATED(message) __attribute__((deprecated(message))) +#elif defined(_MSC_VER) +#define OPENFST_DEPRECATED(message) __declspec(deprecated(message)) +#else +#define OPENFST_DEPRECATED(message) +#endif + +#include +#include +#include +#include +#include +#include + +using std::string; + +void FailedNewHandler(); + +#ifdef _MSC_VER +#include +const char* basename(const char* path); +#define __builtin_popcount __popcnt + +#ifdef _M_X64 +// Using 64-bit MSVC intrinsics. +#define __builtin_popcountll __popcnt64 +inline unsigned int __builtin_ctzll(std::uint64_t w) { + unsigned long v; + return _BitScanForward64(&v, std::uint32_t(w)) ? v : 0; +} +#else +// Using 32-bit MSVC intrinsics. +inline unsigned int __builtin_popcountll(std::uint64_t w) { + return __popcnt(std::uint32_t(w)) + __popcnt(std::uint32_t(w >> 32)); +} +inline unsigned int __builtin_ctzll(std::uint64_t w) { + unsigned long v; + return (_BitScanForward(&v, std::uint32_t(w)) ? v : + _BitScanForward(&v, std::uint32_t(w >> 32)) ? v + 32 : 0); +} +#endif // _M_X64 +#endif // _MSC_VER + +namespace fst { + +// Downcasting. +template +inline To down_cast(From* f) { return static_cast(f); } + +// Bitcasting. +template +inline Dest bit_cast(const Source &source) { + static_assert(sizeof(Dest) == sizeof(Source), + "Bitcasting unsafe for specified types"); + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +// Check sums +class CheckSummer { + public: + CheckSummer() : count_(0) { + check_sum_.resize(kCheckSumLength, '\0'); + } + + void Reset() { + count_ = 0; + for (int i = 0; i < kCheckSumLength; ++i) check_sum_[i] = '\0'; + } + + void Update(void const *data, int size) { + const char *p = reinterpret_cast(data); + for (int i = 0; i < size; ++i) { + check_sum_[(count_++) % kCheckSumLength] ^= p[i]; + } + } + + void Update(string const &data) { + for (int i = 0; i < data.size(); ++i) { + check_sum_[(count_++) % kCheckSumLength] ^= data[i]; + } + } + + string Digest() { return check_sum_; } + + private: + static const int kCheckSumLength = 32; + int count_; + string check_sum_; + + CheckSummer(const CheckSummer &) = delete; + CheckSummer &operator=(const CheckSummer &) = delete; +}; + +} // namespace fst + +#endif // FST_LIB_COMPAT_H_ diff --git a/projects/llm_framework/include/fst/complement.h b/projects/llm_framework/include/fst/complement.h new file mode 100644 index 00000000..64eebc03 --- /dev/null +++ b/projects/llm_framework/include/fst/complement.h @@ -0,0 +1,277 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to complement an FST. + +#ifndef FST_COMPLEMENT_H_ +#define FST_COMPLEMENT_H_ + +#include +#include +#include +#include + +#include +#include + + +namespace fst { + +template +class ComplementFst; + +namespace internal { + +// Implementation of delayed ComplementFst. The algorithm used completes the +// (deterministic) FSA and then exchanges final and non-final states. +// Completion, i.e. ensuring that all labels can be read from every state, is +// accomplished by using ρ-labels, which match all labels that are otherwise +// not found leaving a state. The first state in the output is reserved to be a +// new state that is the destination of all ρ-labels. Each remaining output +// state s corresponds to input state s - 1. The first arc in the output at +// these states is the ρ-label, the remaining arcs correspond to the input +// arcs. +template +class ComplementFstImpl : public FstImpl { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + friend class StateIterator>; + friend class ArcIterator>; + + explicit ComplementFstImpl(const Fst &fst) : fst_(fst.Copy()) { + SetType("complement"); + uint64 props = fst.Properties(kILabelSorted, false); + SetProperties(ComplementProperties(props), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + ComplementFstImpl(const ComplementFstImpl &impl) + : fst_(impl.fst_->Copy()) { + SetType("complement"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() const { + if (Properties(kError)) return kNoStateId; + auto start = fst_->Start(); + return start != kNoStateId ? start + 1 : 0; + } + + // Exchange final and non-final states; makes ρ-destination state final. + Weight Final(StateId s) const { + if (s == 0 || fst_->Final(s - 1) == Weight::Zero()) { + return Weight::One(); + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId s) const { + return s == 0 ? 1 : fst_->NumArcs(s - 1) + 1; + } + + size_t NumInputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumInputEpsilons(s - 1); + } + + size_t NumOutputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumOutputEpsilons(s - 1); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && fst_->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + private: + std::unique_ptr> fst_; +}; + +} // namespace internal + +// Complements an automaton. This is a library-internal operation that +// introduces a (negative) ρ-label; use Difference/DifferenceFst in user code, +// which will not see this label. This version is a delayed FST. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class ComplementFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Impl = internal::ComplementFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + + explicit ComplementFst(const Fst &fst) + : ImplToFst(std::make_shared(fst)) { + static constexpr auto props = + kUnweighted | kNoEpsilons | kIDeterministic | kAcceptor; + if (fst.Properties(props, true) != props) { + FSTERROR() << "ComplementFst: Argument not an unweighted " + << "epsilon-free deterministic acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + ComplementFst(const ComplementFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this FST. See Fst<>::Copy() for further doc. + ComplementFst *Copy(bool safe = false) const override { + return new ComplementFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + inline void InitArcIterator(StateId s, + ArcIteratorData *data) const override; + + // Label that represents the ρ-transition; we use a negative value private to + // the library and which will preserve FST label sort order. + static const Label kRhoLabel = -2; + + private: + using ImplToFst::GetImpl; + + ComplementFst &operator=(const ComplementFst &) = delete; +}; + +template +const typename Arc::Label ComplementFst::kRhoLabel; + +// Specialization for ComplementFst. +template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const ComplementFst &fst) + : siter_(*fst.GetImpl()->fst_), s_(0) {} + + bool Done() const final { return s_ > 0 && siter_.Done(); } + + StateId Value() const final { return s_; } + + void Next() final { + if (s_ != 0) siter_.Next(); + ++s_; + } + + void Reset() final { + siter_.Reset(); + s_ = 0; + } + + private: + StateIterator> siter_; + StateId s_; +}; + +// Specialization for ComplementFst. +template +class ArcIterator> : public ArcIteratorBase { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + ArcIterator(const ComplementFst &fst, StateId s) : s_(s), pos_(0) { + if (s_ != 0) { + aiter_.reset(new ArcIterator>(*fst.GetImpl()->fst_, s - 1)); + } + } + + bool Done() const final { + if (s_ != 0) { + return pos_ > 0 && aiter_->Done(); + } else { + return pos_ > 0; + } + } + + // Adds the ρ-label to the ρ destination state. + const Arc &Value() const final { + if (pos_ == 0) { + arc_.ilabel = arc_.olabel = ComplementFst::kRhoLabel; + arc_.weight = Weight::One(); + arc_.nextstate = 0; + } else { + arc_ = aiter_->Value(); + ++arc_.nextstate; + } + return arc_; + } + + void Next() final { + if (s_ != 0 && pos_ > 0) aiter_->Next(); + ++pos_; + } + + size_t Position() const final { return pos_; } + + void Reset() final { + if (s_ != 0) aiter_->Reset(); + pos_ = 0; + } + + void Seek(size_t a) final { + if (s_ != 0) { + if (a == 0) { + aiter_->Reset(); + } else { + aiter_->Seek(a - 1); + } + } + pos_ = a; + } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + std::unique_ptr>> aiter_; + StateId s_; + size_t pos_; + mutable Arc arc_; +}; + +template +inline void ComplementFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +template +inline void ComplementFst::InitArcIterator(StateId s, + ArcIteratorData *data) const { + data->base = new ArcIterator>(*this, s); +} + +// Useful alias when using StdArc. +using StdComplementFst = ComplementFst; + +} // namespace fst + +#endif // FST_COMPLEMENT_H_ diff --git a/projects/llm_framework/include/fst/compose-filter.h b/projects/llm_framework/include/fst/compose-filter.h new file mode 100644 index 00000000..7251e273 --- /dev/null +++ b/projects/llm_framework/include/fst/compose-filter.h @@ -0,0 +1,571 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for filtering the composition matches, e.g. for correct epsilon +// handling. + +#ifndef FST_COMPOSE_FILTER_H_ +#define FST_COMPOSE_FILTER_H_ + +#include +#include // For optional argument declarations +#include +#include + + +namespace fst { + +// Composition filters determine which matches are allowed to proceed. The +// filter's state is represeted by the type ComposeFilter::FilterState. +// The basic filters handle correct epsilon matching. Their interface is: +// +// template +// class ComposeFilter { +// public: +// using Matcher1 = ...; +// using Matcher2 = ...; +// using FST1 = typename M1::FST; +// using FST2 = typename M2::FST; +// using FilterState = ...; +// +// using Arc = typename FST1::Arc; +// using StateId = typename Arc::StateId; +// using Weight = typename Arc::Weight; +// +// // Required constructor. +// ComposeFilter(const FST1 &fst1, const FST2 &fst2, +// M1 *matcher1 = nullptr, M2 *matcher2 = nullptr); +// +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// ComposeFilter(const ComposeFilter &filter, +// bool safe = false); +// +// // Return start state of filter. +// FilterState Start() const; +// +// // Specifies current composition state. +// void SetState(StateId s1, StateId s2, const FilterState &fs); +// +// // Apply filter at current composition state to these transitions. If an +// // arc label to be matched is kNolabel, then that side does not consume a +// // symbol. Returns the new filter state or, if disallowed, +// // FilterState::NoState(). The filter is permitted to modify its inputs +// // (e.g. for optimization reasons). +// FilterState FilterArc(Arc *arc1, Arc *arc2) const; + +// // Apply filter at current composition state to these final weights +// // (cf. superfinal transitions). The filter may modify its inputs +// // (e.g. for optimization reasons). +// void FilterFinal(Weight *w1, Weight *w2) const; +// +// // Return the respective matchers. Ownership stays with filter. These +// // methods allow the filter to access and possibly modify the compositio +// // matchers (useful, e.g., with lookahead). +// +// Matcher1 *GetMatcher1(); +// +// Matcher2 *GetMatcher2(); +// +// // This specifies how the filter affects the composition result properties. +// It takes as argument the properties that would apply with a trivial +// // composition filter. +// uint64 Properties(uint64 props) const; +// }; +// +// This filter allows only exact matching of symbols from FST1 with on FST2; +// e.g., no special interpretation of epsilons. +template +class NullComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = TrivialFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + NullComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + NullComposeFilter(const NullComposeFilter &filter, bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + FilterState Start() const { return FilterState(true); } + + void SetState(StateId, StateId, const FilterState &) {} + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + return (arc1->olabel == kNoLabel || arc2->ilabel == kNoLabel) + ? FilterState::NoState() + : FilterState(true); + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; +}; + +// This filter allows all epsilon matches, potentially resulting in redundant +// epsilon paths. The use of this filter gives correct results iff one of the +// following conditions hold: +// +// (1) The semiring is idempotent, +// (2) the first FST is output-epsilon free, or +// (3) the second FST is input-epsilon free. +// +// For (1), redundant epsilon paths may be created but won't hurt correctness. +// For (2) and (3), no redundant paths are created. +template +class TrivialComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = TrivialFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + TrivialComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + TrivialComposeFilter(const TrivialComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + FilterState Start() const { return FilterState(true); } + + void SetState(StateId, StateId, const FilterState &) {} + + FilterState FilterArc(Arc *, Arc *) const { return FilterState(true); } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; +}; + +// This filter requires epsilons on FST1 to be read before epsilons on FST2. +template +class SequenceComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = CharFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + SequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + SequenceComposeFilter(const SequenceComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + if (s1_ == s1 && s2_ == s2 && fs == fs_) return; + s1_ = s1; + s2_ = s2; + fs_ = fs; + const auto na1 = internal::NumArcs(fst1_, s1); + const auto ne1 = internal::NumOutputEpsilons(fst1_, s1); + const bool fin1 = internal::Final(fst1_, s1) != Weight::Zero(); + alleps1_ = na1 == ne1 && !fin1; + noeps1_ = ne1 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc1->olabel == kNoLabel) { + return alleps1_ ? FilterState::NoState() : noeps1_ ? FilterState(0) + : FilterState(1); + } else if (arc2->ilabel == kNoLabel) { + return fs_ != FilterState(0) ? FilterState::NoState() : FilterState(0); + } else { + return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0); + } + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + StateId s1_; // Current fst1_ state. + StateId s2_; // Current fst2_ state. + FilterState fs_; // Current filter state. + bool alleps1_; // Only epsilons (and non-final) leaving s1_? + bool noeps1_; // No epsilons leaving s1_? +}; + +// This filter requires epsilons on FST2 to be read before epsilons on FST1. +template +class AltSequenceComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = CharFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + AltSequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + AltSequenceComposeFilter( + const AltSequenceComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + if (s1_ == s1 && s2_ == s2 && fs == fs_) return; + s1_ = s1; + s2_ = s2; + fs_ = fs; + const auto na2 = internal::NumArcs(fst2_, s2); + const auto ne2 = internal::NumInputEpsilons(fst2_, s2); + const bool fin2 = internal::Final(fst2_, s2) != Weight::Zero(); + alleps2_ = na2 == ne2 && !fin2; + noeps2_ = ne2 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc2->ilabel == kNoLabel) { + return alleps2_ ? FilterState::NoState() : noeps2_ ? FilterState(0) + : FilterState(1); + } else if (arc1->olabel == kNoLabel) { + return fs_ == FilterState(1) ? FilterState::NoState() : FilterState(0); + } else { + return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0); + } + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST2 &fst2_; + StateId s1_; // Current fst1_ state. + StateId s2_; // Current fst2_ state. + FilterState fs_; // Current filter state. + bool alleps2_; // Only epsilons (and non-final) leaving s2_? + bool noeps2_; // No epsilons leaving s2_? +}; + +// This filter requires epsilons on FST1 to be matched with epsilons on FST2 +// whenever possible. (Template arg default declared in fst-decl.h.) +template +class MatchComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = CharFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MatchComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + MatchComposeFilter(const MatchComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + if (s1_ == s1 && s2_ == s2 && fs == fs_) return; + s1_ = s1; + s2_ = s2; + fs_ = fs; + size_t na1 = internal::NumArcs(fst1_, s1); + size_t ne1 = internal::NumOutputEpsilons(fst1_, s1); + bool f1 = internal::Final(fst1_, s1) != Weight::Zero(); + alleps1_ = na1 == ne1 && !f1; + noeps1_ = ne1 == 0; + size_t na2 = internal::NumArcs(fst2_, s2); + size_t ne2 = internal::NumInputEpsilons(fst2_, s2); + bool f2 = internal::Final(fst2_, s2) != Weight::Zero(); + alleps2_ = na2 == ne2 && !f2; + noeps2_ = ne2 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc2->ilabel == kNoLabel) { // Epsilon in FST1. + return fs_ == FilterState(0) + ? (noeps2_ + ? FilterState(0) + : (alleps2_ ? FilterState::NoState() : FilterState(1))) + : (fs_ == FilterState(1) ? FilterState(1) + : FilterState::NoState()); + } else if (arc1->olabel == kNoLabel) { // Epsilon in FST2. + return fs_ == FilterState(0) + ? (noeps1_ + ? FilterState(0) + : (alleps1_ ? FilterState::NoState() : FilterState(2))) + : (fs_ == FilterState(2) ? FilterState(2) + : FilterState::NoState()); + } else if (arc1->olabel == 0) { // Epsilon in both. + return fs_ == FilterState(0) ? FilterState(0) : FilterState::NoState(); + } else { // Both are non-epsilons. + return FilterState(0); + } + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; + StateId s1_; // Current fst1_ state. + StateId s2_; // Current fst2_ state. + FilterState fs_; // Current filter state ID. + bool alleps1_; // Only epsilson (and non-final) leaving s1? + bool alleps2_; // Only epsilons (and non-final) leaving s2? + bool noeps1_; // No epsilons leaving s1? + bool noeps2_; // No epsilons leaving s2? +}; + +// This filter disallows matching epsilons on FST1 with epsilons on FST2, +// but allows all other matches, potentially resulting in redundant +// epsilon paths. The use of this filter gives correct results iff one of the +// following conditions hold: +// +// (1) The semiring is idempotent, +// (2) the first FST is output-epsilon free, or +// (3) the second FST is input-epsilon free. +// +// For (1), redundant epsilon paths may be created but won't hurt correctness. +// For (2) and (3), no redundant paths are created. +template +class NoMatchComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = TrivialFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + NoMatchComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + NoMatchComposeFilter(const NoMatchComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + FilterState Start() const { return FilterState(true); } + + void SetState(StateId, StateId, const FilterState &) {} + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + return FilterState(arc1->olabel != 0 || arc2->ilabel != 0); + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; +}; + +// This filter works with the MultiEpsMatcher to determine if multi-epsilons are +// preserved in the composition output (rather than rewritten as 0) and +// ensures correct properties. +template +class MultiEpsFilter { + public: + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + using FST1 = typename Filter::FST1; + using FST2 = typename Filter::FST2; + using FilterState = typename Filter::FilterState; + + using Arc = typename Filter::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MultiEpsFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr, + bool keep_multi_eps = false) + : filter_(fst1, fst2, matcher1, matcher2), + keep_multi_eps_(keep_multi_eps) {} + + MultiEpsFilter(const MultiEpsFilter &filter, bool safe = false) + : filter_(filter.filter_, safe), + keep_multi_eps_(filter.keep_multi_eps_) {} + + FilterState Start() const { return filter_.Start(); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + return filter_.SetState(s1, s2, fs); + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + const auto fs = filter_.FilterArc(arc1, arc2); + if (keep_multi_eps_) { + if (arc1->olabel == kNoLabel) arc1->ilabel = arc2->ilabel; + if (arc2->ilabel == kNoLabel) arc2->olabel = arc1->olabel; + } + return fs; + } + + void FilterFinal(Weight *w1, Weight *w2) const { + return filter_.FilterFinal(w1, w2); + } + + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + const auto oprops = filter_.Properties(iprops); + return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; + } + + private: + Filter filter_; + bool keep_multi_eps_; +}; + +} // namespace fst + +#endif // FST_COMPOSE_FILTER_H_ diff --git a/projects/llm_framework/include/fst/compose.h b/projects/llm_framework/include/fst/compose.h new file mode 100644 index 00000000..1066d097 --- /dev/null +++ b/projects/llm_framework/include/fst/compose.h @@ -0,0 +1,1035 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to compute the composition of two FSTs. + +#ifndef FST_COMPOSE_H_ +#define FST_COMPOSE_H_ + +#include +#include +#include + +#include + +#include +#include +#include // For optional argument declarations +#include +#include +#include +#include + + +namespace fst { + +// Delayed composition options templated on the arc type, the matcher, +// the composition filter, and the composition state table. By +// default, the matchers, filter, and state table are constructed by +// composition. If set below, the user can instead pass in these +// objects; in that case, ComposeFst takes their ownership. This +// version controls composition implemented between generic Fst +// types and a shared matcher type M for Fst. This should be +// adequate for most applications, giving a reasonable tradeoff +// between efficiency and code sharing (but see ComposeFstImplOptions). +template >, + class Filter = SequenceComposeFilter, + class StateTable = + GenericComposeStateTable> +struct ComposeFstOptions : public CacheOptions { + M *matcher1; // FST1 matcher. + M *matcher2; // FST2 matcher. + Filter *filter; // Composition filter. + StateTable *state_table; // Composition state table. + + explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(), + M *matcher1 = nullptr, M *matcher2 = nullptr, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : CacheOptions(opts), + matcher1(matcher1), + matcher2(matcher2), + filter(filter), + state_table(state_table) {} +}; + +// Forward declaration of ComposeFstMatcher. +template +class ComposeFstMatcher; + +// Delayed composition options templated on the two matcher types, the +// composition filter, the composition state table and the cache store. By +// default, the matchers, filter, state table and cache store are constructed +// by composition. If set below, the user can instead pass in these objects; in +// that case, ComposeFst takes their ownership. This version controls +// composition implemented using arbitrary matchers (of the same arc type but +// otherwise arbitrary FST type). The user must ensure the matchers are +// compatible. These options permit the most efficient use, but shares the +// least code. This is for advanced use only in the most demanding or +// specialized applications that can benefit from it; otherwise, prefer +// ComposeFstOptions). +template , + class StateTable = GenericComposeStateTable< + typename M1::Arc, typename Filter::FilterState>, + class CacheStore = DefaultCacheStore> +struct ComposeFstImplOptions : public CacheImplOptions { + M1 *matcher1; // FST1 matcher (see matcher.h).... + M2 *matcher2; // FST2 matcher. + Filter *filter; // Composition filter (see compose-filter.h). + StateTable + *state_table; // Composition state table (see compose-state-table.h). + bool own_state_table; // ComposeFstImpl takes ownership of 'state_table'? + bool allow_noncommute; // Allow non-commutative weights + + explicit ComposeFstImplOptions(const CacheOptions &opts, + M1 *matcher1 = nullptr, M2 *matcher2 = nullptr, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : CacheImplOptions(opts), + matcher1(matcher1), + matcher2(matcher2), + filter(filter), + state_table(state_table), + own_state_table(true), + allow_noncommute(false) {} + + explicit ComposeFstImplOptions(const CacheImplOptions &opts, + M1 *matcher1 = nullptr, M2 *matcher2 = nullptr, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : CacheImplOptions(opts), + matcher1(matcher1), + matcher2(matcher2), + filter(filter), + state_table(state_table), + own_state_table(true), + allow_noncommute(false) {} + + ComposeFstImplOptions() + : matcher1(nullptr), + matcher2(nullptr), + filter(nullptr), + state_table(nullptr), + own_state_table(true), + allow_noncommute(false) {} +}; + +namespace internal { + +// Implementation of delayed composition. This base class is common to the +// variants with different matchers, composition filters and state tables. +template , + class F = ComposeFst> +class ComposeFstImplBase + : public CacheBaseImpl { + public: + using FST = F; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using State = typename CacheStore::State; + using CacheImpl = CacheBaseImpl; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::HasStart; + using CacheImpl::HasFinal; + using CacheImpl::HasArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + ComposeFstImplBase(const CacheImplOptions &opts) + : CacheImpl(opts) {} + + ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {} + + ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) { + SetType(impl.Type()); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + virtual ComposeFstImplBase *Copy() const = 0; + + ~ComposeFstImplBase() override {} + + StateId Start() { + if (!HasStart()) { + const auto start = ComputeStart(); + if (start != kNoStateId) SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) SetFinal(s, ComputeFinal(s)); + return CacheImpl::Final(s); + } + + virtual void Expand(StateId s) = 0; + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + virtual MatcherBase *InitMatcher(const F &fst, + MatchType match_type) const { + // Use the default matcher if no override is provided. + return nullptr; + } + + protected: + virtual StateId ComputeStart() = 0; + virtual Weight ComputeFinal(StateId s) = 0; +}; + +// Implementation of delayed composition templated on the matchers (see +// matcher.h), composition filter (see compose-filter.h) and the composition +// state table (see compose-state-table.h). +template +class ComposeFstImpl + : public ComposeFstImplBase { + public: + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + + using FST1 = typename Matcher1::FST; + using FST2 = typename Matcher2::FST; + + using Arc = typename CacheStore::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = typename Filter::FilterState; + using State = typename CacheStore::State; + + using CacheImpl = CacheBaseImpl; + + using StateTuple = typename StateTable::StateTuple; + + friend class ComposeFstMatcher; + + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::SetProperties; + + template + ComposeFstImpl(const FST1 &fst1, const FST2 &fst2, + const ComposeFstImplOptions &opts); + + ComposeFstImpl(const ComposeFstImpl &impl) + : ComposeFstImplBase(impl), + filter_(new Filter(*impl.filter_, true)), + matcher1_(filter_->GetMatcher1()), + matcher2_(filter_->GetMatcher2()), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + state_table_(new StateTable(*impl.state_table_)), + own_state_table_(true), + match_type_(impl.match_type_) {} + + ~ComposeFstImpl() override { + if (own_state_table_) delete state_table_; + } + + ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && + (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) || + (matcher1_->Properties(0) & kError) || + (matcher2_->Properties(0) & kError) | + (filter_->Properties(0) & kError) || + state_table_->Error())) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + // Arranges it so that the first arg to OrderedExpand is the Fst + // that will be matched on. + void Expand(StateId s) override { + const auto &tuple = state_table_->Tuple(s); + const auto s1 = tuple.StateId1(); + const auto s2 = tuple.StateId2(); + filter_->SetState(s1, s2, tuple.GetFilterState()); + if (MatchInput(s1, s2)) { + OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true); + } else { + OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false); + } + } + + const FST1 &GetFst1() const { return fst1_; } + + const FST2 &GetFst2() const { return fst2_; } + + const Matcher1 *GetMatcher1() const { return matcher1_; } + + Matcher1 *GetMatcher1() { return matcher1_; } + + const Matcher2 *GetMatcher2() const { return matcher2_; } + + Matcher2 *GetMatcher2() { return matcher2_; } + + const Filter *GetFilter() const { return filter_.get(); } + + Filter *GetFilter() { return filter_.get(); } + + const StateTable *GetStateTable() const { return state_table_; } + + StateTable *GetStateTable() { return state_table_; } + + MatcherBase *InitMatcher(const ComposeFst &fst, + MatchType match_type) const override { + const auto test_props = match_type == MATCH_INPUT + ? kFstProperties & ~kILabelInvariantProperties + : kFstProperties & ~kOLabelInvariantProperties; + // If both matchers support 'match_type' and we have a guarantee that a + // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of + // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then + // ComposeFstMatcher can be used. + if ((matcher1_->Type(false) == match_type) && + (matcher2_->Type(false) == match_type) && + (filter_->Properties(test_props) == test_props)) { + return new ComposeFstMatcher< + CacheStore, Filter, StateTable>(&fst, match_type); + } + return nullptr; + } + + private: + // This does that actual matching of labels in the composition. The + // arguments are ordered so matching is called on state 'sa' of + // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg + // determines whether the input or output label of arcs at 'sb' is + // the one to match on. + template + void OrderedExpand(StateId s, const Fst &, StateId sa, const FST &fstb, + StateId sb, Matcher *matchera, bool match_input) { + matchera->SetState(sa); + // First processes non-consuming symbols (e.g., epsilons) on FSTA. + const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0, + Weight::One(), sb); + MatchArc(s, matchera, loop, match_input); + // Then processes matches on FSTB. + for (ArcIterator iterb(fstb, sb); !iterb.Done(); iterb.Next()) { + MatchArc(s, matchera, iterb.Value(), match_input); + } + CacheImpl::SetArcs(s); + } + + // Matches a single transition from 'fstb' against 'fata' at 's'. + template + void MatchArc(StateId s, Matcher *matchera, const Arc &arc, + bool match_input) { + if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) { + for (; !matchera->Done(); matchera->Next()) { + auto arca = matchera->Value(); + auto arcb = arc; + if (match_input) { + const auto &fs = filter_->FilterArc(&arcb, &arca); + if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs); + } else { + const auto &fs = filter_->FilterArc(&arca, &arcb); + if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs); + } + } + } + } + + // Add a matching transition at 's'. + void AddArc(StateId s, const Arc &arc1, const Arc &arc2, + const FilterState &f) { + const StateTuple tuple(arc1.nextstate, arc2.nextstate, f); + CacheImpl::EmplaceArc( + s, arc1.ilabel, arc2.olabel, Times(arc1.weight, arc2.weight), + state_table_->FindState(tuple)); + } + + StateId ComputeStart() override { + const auto s1 = fst1_.Start(); + if (s1 == kNoStateId) return kNoStateId; + const auto s2 = fst2_.Start(); + if (s2 == kNoStateId) return kNoStateId; + const auto &fs = filter_->Start(); + const StateTuple tuple(s1, s2, fs); + return state_table_->FindState(tuple); + } + + Weight ComputeFinal(StateId s) override { + const auto &tuple = state_table_->Tuple(s); + const auto s1 = tuple.StateId1(); + auto final1 = matcher1_->Final(s1); + if (final1 == Weight::Zero()) return final1; + const auto s2 = tuple.StateId2(); + auto final2 = matcher2_->Final(s2); + if (final2 == Weight::Zero()) return final2; + filter_->SetState(s1, s2, tuple.GetFilterState()); + filter_->FilterFinal(&final1, &final2); + return Times(final1, final2); + } + + // Determines which side to match on per composition state. + bool MatchInput(StateId s1, StateId s2) { + switch (match_type_) { + case MATCH_INPUT: + return true; + case MATCH_OUTPUT: + return false; + default: // MATCH_BOTH + const auto priority1 = matcher1_->Priority(s1); + const auto priority2 = matcher2_->Priority(s2); + if (priority1 == kRequirePriority && priority2 == kRequirePriority) { + FSTERROR() << "ComposeFst: Both sides can't require match"; + SetProperties(kError, kError); + return true; + } + if (priority1 == kRequirePriority) return false; + if (priority2 == kRequirePriority) { + return true; + } + return priority1 <= priority2; + } + } + + // Identifies and verifies the capabilities of the matcher to be used for + // composition. + void SetMatchType(); + + std::unique_ptr filter_; + Matcher1 *matcher1_; // Borrowed reference. + Matcher2 *matcher2_; // Borrowed reference. + const FST1 &fst1_; + const FST2 &fst2_; + StateTable *state_table_; + bool own_state_table_; + + MatchType match_type_; +}; + +template +template +ComposeFstImpl::ComposeFstImpl( + const FST1 &fst1, const FST2 &fst2, + const ComposeFstImplOptions &opts) + : ComposeFstImplBase(opts), + filter_(opts.filter + ? opts.filter + : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)), + matcher1_(filter_->GetMatcher1()), + matcher2_(filter_->GetMatcher2()), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + state_table_(opts.state_table ? opts.state_table + : new StateTable(fst1_, fst2_)), + own_state_table_(opts.state_table ? opts.own_state_table : true) { + SetType("compose"); + if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) { + FSTERROR() << "ComposeFst: Output symbol table of 1st argument " + << "does not match input symbol table of 2nd argument"; + SetProperties(kError, kError); + } + SetInputSymbols(fst1_.InputSymbols()); + SetOutputSymbols(fst2_.OutputSymbols()); + SetMatchType(); + VLOG(2) << "ComposeFstImpl: Match type: " << match_type_; + if (match_type_ == MATCH_NONE) SetProperties(kError, kError); + const auto fprops1 = fst1.Properties(kFstProperties, false); + const auto fprops2 = fst2.Properties(kFstProperties, false); + const auto mprops1 = matcher1_->Properties(fprops1); + const auto mprops2 = matcher2_->Properties(fprops2); + const auto cprops = ComposeProperties(mprops1, mprops2); + SetProperties(filter_->Properties(cprops), kCopyProperties); + if (state_table_->Error()) SetProperties(kError, kError); +} + +template +void ComposeFstImpl::SetMatchType() { + // Ensures any required matching is possible and known. + if ((matcher1_->Flags() & kRequireMatch) && + matcher1_->Type(true) != MATCH_OUTPUT) { + FSTERROR() << "ComposeFst: 1st argument cannot perform required matching " + << "(sort?)."; + match_type_ = MATCH_NONE; + return; + } + if ((matcher2_->Flags() & kRequireMatch) && + matcher2_->Type(true) != MATCH_INPUT) { + FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching " + << "(sort?)."; + match_type_ = MATCH_NONE; + return; + } + // Finds which sides to match on (favoring minimal testing of capabilities). + const auto type1 = matcher1_->Type(false); + const auto type2 = matcher2_->Type(false); + if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { + match_type_ = MATCH_BOTH; + } else if (type1 == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (type2 == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else if (matcher1_->Type(true) == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (matcher2_->Type(true) == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else { + FSTERROR() << "ComposeFst: 1st argument cannot match on output labels " + << "and 2nd argument cannot match on input labels (sort?)."; + match_type_ = MATCH_NONE; + } +} + +} // namespace internal + +// Computes the composition of two transducers. This version is a delayed FST. +// If FST1 transduces string x to y with weight a and FST2 transduces y to z +// with weight b, then their composition transduces string x to z with weight +// Times(x, z). +// +// The output labels of the first transducer or the input labels of the second +// transducer must be sorted (with the default matcher). The weights need to +// form a commutative semiring (valid for TropicalWeight and LogWeight). +// +// Complexity: +// +// Assuming the first FST is unsorted and the second is sorted, +// +// Time: O(v1 v2 d1 (log d2 + m2)), +// Space: O(v1 v2) +// +// where vi = # of states visited, di = maximum out-degree, and mi the +// maximum multiplicity of the states visited, for the ith FST. Constant time +// and space to visit an input state or arc is assumed and exclusive of caching. +// +// Caveats: +// - ComposeFst does not trim its output (since it is a delayed operation). +// - The efficiency of composition can be strongly affected by several factors: +// - the choice of which transducer is sorted - prefer sorting the FST +// that has the greater average out-degree. +// - the amount of non-determinism +// - the presence and location of epsilon transitions - avoid epsilon +// transitions on the output side of the first transducer or +// the input side of the second transducer or prefer placing +// them later in a path since they delay matching and can +// introduce non-coaccessible states and transitions. +// +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToFst. The CacheStore specifies the +// cache store (default declared in fst-decl.h). +template */> +class ComposeFst + : public ImplToFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = CacheStore; + using State = typename CacheStore::State; + + using Impl = internal::ComposeFstImplBase; + + friend class ArcIterator>; + friend class StateIterator>; + template friend class ComposeFstMatcher; + + // Compose specifying only caching options. + ComposeFst(const Fst &fst1, const Fst &fst2, + const CacheOptions &opts = CacheOptions()) + : ImplToFst(CreateBase(fst1, fst2, opts)) {} + + // Compose specifying one shared matcher type M. Requires that the input FSTs + // and matcher FST types be Fst. Recommended for best code-sharing and + // matcher compatiblity. + template + ComposeFst(const Fst &fst1, const Fst &fst2, + const ComposeFstOptions &opts) + : ImplToFst(CreateBase1(fst1, fst2, opts)) {} + + // Compose specifying two matcher types Matcher1 and Matcher2. Requires input + // FST (of the same Arc type, but o.w. arbitrary) match the corresponding + // matcher FST types). Recommended only for advanced use in demanding or + // specialized applications due to potential code bloat and matcher + // incompatibilities. + template + ComposeFst(const typename Matcher1::FST &fst1, + const typename Matcher2::FST &fst2, + const ComposeFstImplOptions &opts) + : ImplToFst(CreateBase2(fst1, fst2, opts)) {} + + // See Fst<>::Copy() for doc. + ComposeFst(const ComposeFst &fst, bool safe = false) + : ImplToFst(safe ? std::shared_ptr(fst.GetImpl()->Copy()) + : fst.GetSharedImpl()) {} + + // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc. + ComposeFst *Copy(bool safe = false) const override { + return new ComposeFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return GetImpl()->InitMatcher(*this, match_type); + } + + protected: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + explicit ComposeFst(std::shared_ptr impl) : ImplToFst(impl) {} + + // Create compose implementation specifying two matcher types. + template + static std::shared_ptr CreateBase2( + const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2, + const ComposeFstImplOptions &opts) { + auto impl = std::make_shared< + internal::ComposeFstImpl>(fst1, fst2, + opts); + if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) { + const auto props1 = fst1.Properties(kUnweighted, true); + const auto props2 = fst2.Properties(kUnweighted, true); + if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) { + FSTERROR() << "ComposeFst: Weights must be a commutative semiring: " + << Weight::Type(); + impl->SetProperties(kError, kError); + } + } + return impl; + } + + // Create compose implementation specifying one matcher type; requires that + // input and matcher FST types be Fst. + template + static std::shared_ptr CreateBase1( + const Fst &fst1, const Fst &fst2, + const ComposeFstOptions &opts) { + ComposeFstImplOptions + nopts(opts, opts.matcher1, opts.matcher2, opts.filter, + opts.state_table); + return CreateBase2(fst1, fst2, nopts); + } + + // Create compose implementation specifying no matcher type. + static std::shared_ptr CreateBase(const Fst &fst1, + const Fst &fst2, + const CacheOptions &opts) { + switch (LookAheadMatchType(fst1, fst2)) { // Check for lookahead matchers + default: + case MATCH_NONE: { // Default composition (no look-ahead). + ComposeFstOptions nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + case MATCH_OUTPUT: { // Lookahead on fst1. + using M = typename DefaultLookAhead::FstMatcher; + using F = typename DefaultLookAhead::ComposeFilter; + ComposeFstOptions nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + case MATCH_INPUT: { // Lookahead on fst2 + using M = typename DefaultLookAhead::FstMatcher; + using F = typename DefaultLookAhead::ComposeFilter; + ComposeFstOptions nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + } + } + + private: + ComposeFst &operator=(const ComposeFst &fst) = delete; +}; + +// Specialization for ComposeFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const ComposeFst &fst) + : CacheStateIterator>(fst, + fst.GetMutableImpl()) {} +}; + +// Specialization for ComposeFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ComposeFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void ComposeFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT, +// iff the underlying matchers for the two FSTS being composed support +// MATCH_INPUT or MATCH_OUTPUT, respectively. +template +class ComposeFstMatcher : public MatcherBase { + public: + using Arc = typename CacheStore::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + using FilterState = typename Filter::FilterState; + + using StateTuple = typename StateTable::StateTuple; + using Impl = internal::ComposeFstImpl; + + // The compose FST arg must match the filter and state table types. + // This makes a copy of the FST. + ComposeFstMatcher(const ComposeFst &fst, + MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(match_type), + matcher1_(impl_->matcher1_->Copy()), + matcher2_(impl_->matcher2_->Copy()), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + // The compose FST arg must match the filter and state table types. + // This doesn't copy the FST (although it may copy components). + ComposeFstMatcher(const ComposeFst *fst, + MatchType match_type) + : fst_(*fst), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(match_type), + matcher1_(impl_->matcher1_->Copy()), + matcher2_(impl_->matcher2_->Copy()), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + // This makes a copy of the FST. + ComposeFstMatcher( + const ComposeFstMatcher &matcher, + bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(matcher.match_type_), + matcher1_(matcher.matcher1_->Copy(safe)), + matcher2_(matcher.matcher2_->Copy(safe)), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + ComposeFstMatcher *Copy( + bool safe = false) const override { + return new ComposeFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if ((matcher1_->Type(test) == MATCH_NONE) || + (matcher2_->Type(test) == MATCH_NONE)) { + return MATCH_NONE; + } + if (((matcher1_->Type(test) == MATCH_UNKNOWN) && + (matcher2_->Type(test) == MATCH_UNKNOWN)) || + ((matcher1_->Type(test) == MATCH_UNKNOWN) && + (matcher2_->Type(test) == match_type_)) || + ((matcher1_->Type(test) == match_type_) && + (matcher2_->Type(test) == MATCH_UNKNOWN))) { + return MATCH_UNKNOWN; + } + if ((matcher1_->Type(test) == match_type_) && + (matcher2_->Type(test) == match_type_)) { + return match_type_; + } + return MATCH_NONE; + } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops; + } + + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + const auto &tuple = impl_->state_table_->Tuple(s); + matcher1_->SetState(tuple.StateId1()); + matcher2_->SetState(tuple.StateId2()); + loop_.nextstate = s_; + } + + bool Find(Label label) final { + bool found = false; + current_loop_ = false; + if (label == 0) { + current_loop_ = true; + found = true; + } + if (match_type_ == MATCH_INPUT) { + found = found || FindLabel(label, matcher1_.get(), matcher2_.get()); + } else { // match_type_ == MATCH_OUTPUT + found = found || FindLabel(label, matcher2_.get(), matcher1_.get()); + } + return found; + } + + bool Done() const final { + return !current_loop_ && matcher1_->Done() && matcher2_->Done(); + } + + const Arc &Value() const final { return current_loop_ ? loop_ : arc_; } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else if (match_type_ == MATCH_INPUT) { + FindNext(matcher1_.get(), matcher2_.get()); + } else { // match_type_ == MATCH_OUTPUT + FindNext(matcher2_.get(), matcher1_.get()); + } + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + // Processes a match with the filter and creates resulting arc. + bool MatchArc(StateId s, Arc arc1, + Arc arc2) { // FIXME(kbg): copy but not assignment. + const auto &fs = impl_->filter_->FilterArc(&arc1, &arc2); + if (fs == FilterState::NoState()) return false; + const StateTuple tuple(arc1.nextstate, arc2.nextstate, fs); + arc_.ilabel = arc1.ilabel; + arc_.olabel = arc2.olabel; + arc_.weight = Times(arc1.weight, arc2.weight); + arc_.nextstate = impl_->state_table_->FindState(tuple); + return true; + } + + // Finds the first match allowed by the filter. + template + bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) { + if (matchera->Find(label)) { + matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel + : matchera->Value().ilabel); + return FindNext(matchera, matcherb); + } + return false; + } + + // Finds the next match allowed by the filter, returning true iff such a + // match is found. + template + bool FindNext(MatcherA *matchera, MatcherB *matcherb) { + // State when entering this function: + // 'matchera' is pointed to a match x, y for label x, and a match for y was + // requested on 'matcherb'. + while (!matchera->Done() || !matcherb->Done()) { + if (matcherb->Done()) { + // If no more matches for y on 'matcherb', moves forward on 'matchera' + // until a match x, y' is found such that there is a match for y' on + // 'matcherb'. + matchera->Next(); + while (!matchera->Done() && + !matcherb->Find(match_type_ == MATCH_INPUT + ? matchera->Value().olabel + : matchera->Value().ilabel)) { + matchera->Next(); + } + } + while (!matcherb->Done()) { + // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is + // pointing to a match y', z' ('arcb'). If combining these two arcs is + // allowed by the filter (hence resulting in an arc x, z') return true. + // Position 'matcherb' on the next potential match for y' before + // returning. + const auto &arca = matchera->Value(); + const auto &arcb = matcherb->Value(); + // Position 'matcherb' on the next potential match for y'. + matcherb->Next(); + // Returns true If combining these two arcs is allowed by the filter + // (hence resulting in an arc x, z'); otherwise consider next match + // for y' on 'matcherb'. + if (MatchArc(s_, match_type_ == MATCH_INPUT ? arca : arcb, + match_type_ == MATCH_INPUT ? arcb : arca)) { + return true; + } + } + } + // Both 'matchera' and 'matcherb' are done, no more match to analyse. + return false; + } + + std::unique_ptr> owned_fst_; + const ComposeFst &fst_; + const Impl *impl_; + StateId s_; + MatchType match_type_; + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + bool current_loop_; + Arc loop_; + Arc arc_; +}; + +// Useful alias when using StdArc. +using StdComposeFst = ComposeFst; + +enum ComposeFilter { + AUTO_FILTER, + NULL_FILTER, + TRIVIAL_FILTER, + SEQUENCE_FILTER, + ALT_SEQUENCE_FILTER, + MATCH_FILTER, + NO_MATCH_FILTER +}; + +struct ComposeOptions { + bool connect; // Connect output? + ComposeFilter filter_type; // Pre-defined filter to use. + + explicit ComposeOptions(bool connect = true, + ComposeFilter filter_type = AUTO_FILTER) + : connect(connect), filter_type(filter_type) {} +}; + +// Computes the composition of two transducers. This version writes +// the composed FST into a MutableFst. If FST1 transduces string x to +// y with weight a and FST2 transduces y to z with weight b, then +// their composition transduces string x to z with weight +// Times(x, z). +// +// The output labels of the first transducer or the input labels of +// the second transducer must be sorted. The weights need to form a +// commutative semiring (valid for TropicalWeight and LogWeight). +// +// Complexity: +// +// Assuming the first FST is unsorted and the second is sorted: +// +// Time: O(V1 V2 D1 (log D2 + M2)), +// Space: O(V1 V2 D1 M2) +// +// where Vi = # of states, Di = maximum out-degree, and Mi is the maximum +// multiplicity, for the ith FST. +// +// Caveats: +// +// - Compose trims its output. +// - The efficiency of composition can be strongly affected by several factors: +// - the choice of which transducer is sorted - prefer sorting the FST +// that has the greater average out-degree. +// - the amount of non-determinism +// - the presence and location of epsilon transitions - avoid epsilon +// transitions on the output side of the first transducer or +// the input side of the second transducer or prefer placing +// them later in a path since they delay matching and can +// introduce non-coaccessible states and transitions. +template +void Compose(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const ComposeOptions &opts = ComposeOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, nopts); + break; + } + case NULL_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case SEQUENCE_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case ALT_SEQUENCE_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case MATCH_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case NO_MATCH_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case TRIVIAL_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/concat.h b/projects/llm_framework/include/fst/concat.h new file mode 100644 index 00000000..74d22c22 --- /dev/null +++ b/projects/llm_framework/include/fst/concat.h @@ -0,0 +1,220 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the concatenation of two FSTs. + +#ifndef FST_CONCAT_H_ +#define FST_CONCAT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Computes the concatenation (product) of two FSTs. If FST1 transduces string +// x to y with weight a and FST2 transduces string w to v with weight b, then +// their concatenation transduces string xw to yv with weight Times(a, b). +// +// This version modifies its MutableFst argument (in first position). +// +// Complexity: +// +// Time: O(V1 + V2 + E2) +// Space: O(V1 + V2 + E2) +// +// where Vi is the number of states, and Ei is the number of arcs, of the ith +// FST. +template +void Concat(MutableFst *fst1, const Fst &fst2) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + // Checks that the symbol table are compatible. + if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Concat: Input/output symbol tables of 1st argument " + << "does not match input/output symbol tables of 2nd argument"; + fst1->SetProperties(kError, kError); + return; + } + const auto props1 = fst1->Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + const auto start1 = fst1->Start(); + if (start1 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + const auto numstates1 = fst1->NumStates(); + if (fst2.Properties(kExpanded, false)) { + fst1->ReserveStates(numstates1 + CountStates(fst2)); + } + for (StateIterator> siter2(fst2); !siter2.Done(); siter2.Next()) { + const auto s1 = fst1->AddState(); + const auto s2 = siter2.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator> aiter(fst2, s2); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate += numstates1; + fst1->AddArc(s1, arc); + } + } + const auto start2 = fst2.Start(); + for (StateId s1 = 0; s1 < numstates1; ++s1) { + const auto weight = fst1->Final(s1); + if (weight != Weight::Zero()) { + fst1->SetFinal(s1, Weight::Zero()); + if (start2 != kNoStateId) { + fst1->AddArc(s1, Arc(0, 0, weight, start2 + numstates1)); + } + } + } + if (start2 != kNoStateId) { + fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties); + } +} + +// Computes the concatentation of two FSTs. This version modifies its +// MutableFst argument (in second position). +// +// Complexity: +// +// Time: O(V1 + E1) +// Space: O(V1 + E1) +// +// where Vi is the number of states, and Ei is the number of arcs, of the ith +// FST. +template +void Concat(const Fst &fst1, MutableFst *fst2) { + using Weight = typename Arc::Weight; + // Checks that the symbol table are compatible. + if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) { + FSTERROR() << "Concat: Input/output symbol tables of 1st argument " + << "does not match input/output symbol tables of 2nd argument"; + fst2->SetProperties(kError, kError); + return; + } + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2->Properties(kFstProperties, false); + const auto start2 = fst2->Start(); + if (start2 == kNoStateId) { + if (props1 & kError) fst2->SetProperties(kError, kError); + return; + } + const auto numstates2 = fst2->NumStates(); + if (fst1.Properties(kExpanded, false)) { + fst2->ReserveStates(numstates2 + CountStates(fst1)); + } + for (StateIterator> siter(fst1); !siter.Done(); siter.Next()) { + const auto s1 = siter.Value(); + const auto s2 = fst2->AddState(); + const auto weight = fst1.Final(s1); + if (weight != Weight::Zero()) { + fst2->ReserveArcs(s2, fst1.NumArcs(s1) + 1); + fst2->AddArc(s2, Arc(0, 0, weight, start2)); + } else { + fst2->ReserveArcs(s2, fst1.NumArcs(s1)); + } + for (ArcIterator> aiter(fst1, s1); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate += numstates2; + fst2->AddArc(s2, arc); + } + } + const auto start1 = fst1.Start(); + if (start1 != kNoStateId) { + fst2->SetStart(start1 + numstates2); + fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties); + } else { + fst2->SetStart(fst2->AddState()); + } +} + +// Computes the concatentation of two FSTs. This version modifies its +// RationalFst input (in first position). +template +void Concat(RationalFst *fst1, const Fst &fst2) { + fst1->GetMutableImpl()->AddConcat(fst2, true); +} + +// Computes the concatentation of two FSTs. This version modifies its +// RationalFst input (in second position). +template +void Concat(const Fst &fst1, RationalFst *fst2) { + fst2->GetMutableImpl()->AddConcat(fst1, false); +} + +using ConcatFstOptions = RationalFstOptions; + +// Computes the concatenation (product) of two FSTs; this version is a delayed +// FST. If FST1 transduces string x to y with weight a and FST2 transduces +// string w to v with weight b, then their concatenation transduces string xw +// to yv with Times(a, b). +// +// Complexity: +// +// Time: O(v1 + e1 + v2 + e2), +// Space: O(v1 + v2) +// +// where vi is the number of states visited, and ei is the number of arcs +// visited, of the ith FST. Constant time and space to visit an input state or +// arc is assumed and exclusive of caching. +template +class ConcatFst : public RationalFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + ConcatFst(const Fst &fst1, const Fst &fst2) { + GetMutableImpl()->InitConcat(fst1, fst2); + } + + ConcatFst(const Fst &fst1, const Fst &fst2, + const ConcatFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitConcat(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + ConcatFst(const ConcatFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc. + ConcatFst *Copy(bool safe = false) const override { + return new ConcatFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for ConcatFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const ConcatFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for ConcatFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ConcatFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdConcatFst = ConcatFst; + +} // namespace fst + +#endif // FST_CONCAT_H_ diff --git a/projects/llm_framework/include/fst/config.h b/projects/llm_framework/include/fst/config.h new file mode 100644 index 00000000..32f2a653 --- /dev/null +++ b/projects/llm_framework/include/fst/config.h @@ -0,0 +1,3 @@ +// Windows-specific OpenFst config file +// No dynamic registration. +#define FST_NO_DYNAMIC_LINKING 1 diff --git a/projects/llm_framework/include/fst/config.h.in b/projects/llm_framework/include/fst/config.h.in new file mode 100644 index 00000000..7815dfcd --- /dev/null +++ b/projects/llm_framework/include/fst/config.h.in @@ -0,0 +1,11 @@ +// OpenFst config file + +/* Define to 1 if you have the ICU library. */ +#undef HAVE_ICU + +/* Define to 1 if the system has the type `std::tr1::hash'. */ +#define HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ 1 + +/* Define to 1 if the system has the type `__gnu_cxx::slist'. */ +#define HAVE___GNU_CXX__SLIST_INT_ 1 diff --git a/projects/llm_framework/include/fst/connect.h b/projects/llm_framework/include/fst/connect.h new file mode 100644 index 00000000..4c989292 --- /dev/null +++ b/projects/llm_framework/include/fst/connect.h @@ -0,0 +1,323 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions to remove unsuccessful paths from an FST. + +#ifndef FST_CONNECT_H_ +#define FST_CONNECT_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +// Finds and returns connected components. Use with Visit(). +template +class CcVisitor { + public: + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + + // cc[i]: connected component number for state i. + explicit CcVisitor(std::vector *cc) + : comps_(new UnionFind(0, kNoStateId)), cc_(cc), nstates_(0) {} + + // comps: connected components equiv classes. + explicit CcVisitor(UnionFind *comps) + : comps_(comps), cc_(nullptr), nstates_(0) {} + + ~CcVisitor() { + if (cc_) delete comps_; + } + + void InitVisit(const Fst &fst) {} + + bool InitState(StateId s, StateId root) { + ++nstates_; + if (comps_->FindSet(s) == kNoStateId) comps_->MakeSet(s); + return true; + } + + bool WhiteArc(StateId s, const Arc &arc) { + comps_->MakeSet(arc.nextstate); + comps_->Union(s, arc.nextstate); + return true; + } + + bool GreyArc(StateId s, const Arc &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + bool BlackArc(StateId s, const Arc &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + void FinishState(StateId s) {} + + void FinishVisit() { + if (cc_) GetCcVector(cc_); + } + + // Returns number of components. + // cc[i]: connected component number for state i. + int GetCcVector(std::vector *cc) { + cc->clear(); + cc->resize(nstates_, kNoStateId); + StateId ncomp = 0; + for (StateId s = 0; s < nstates_; ++s) { + const auto rep = comps_->FindSet(s); + auto &comp = (*cc)[rep]; + if (comp == kNoStateId) { + comp = ncomp; + ++ncomp; + } + (*cc)[s] = comp; + } + return ncomp; + } + + private: + UnionFind *comps_; // Components. + std::vector *cc_; // State's cc number. + StateId nstates_; // State count. +}; + +// Finds and returns strongly-connected components, accessible and +// coaccessible states and related properties. Uses Tarjan's single +// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer +// Algorithms", 189pp). Use with DfsVisit(); +template +class SccVisitor { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // scc[i]: strongly-connected component number for state i. + // SCC numbers will be in topological order for acyclic input. + // access[i]: accessibility of state i. + // coaccess[i]: coaccessibility of state i. + // Any of above can be NULL. + // props: related property bits (cyclicity, initial cyclicity, + // accessibility, coaccessibility) set/cleared (o.w. unchanged). + SccVisitor(std::vector *scc, std::vector *access, + std::vector *coaccess, uint64 *props) + : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {} + explicit SccVisitor(uint64 *props) + : scc_(nullptr), access_(nullptr), coaccess_(nullptr), props_(props) {} + + void InitVisit(const Fst &fst); + + bool InitState(StateId s, StateId root); + + bool TreeArc(StateId s, const Arc &arc) { return true; } + + bool BackArc(StateId s, const Arc &arc) { + const auto t = arc.nextstate; + if ((*dfnumber_)[t] < (*lowlink_)[s]) (*lowlink_)[s] = (*dfnumber_)[t]; + if ((*coaccess_)[t]) (*coaccess_)[s] = true; + *props_ |= kCyclic; + *props_ &= ~kAcyclic; + if (t == start_) { + *props_ |= kInitialCyclic; + *props_ &= ~kInitialAcyclic; + } + return true; + } + + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + const auto t = arc.nextstate; + if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ && (*onstack_)[t] && + (*dfnumber_)[t] < (*lowlink_)[s]) { + (*lowlink_)[s] = (*dfnumber_)[t]; + } + if ((*coaccess_)[t]) (*coaccess_)[s] = true; + return true; + } + + // Last argument always ignored, but required by the interface. + void FinishState(StateId state, StateId p, const Arc *); + + void FinishVisit() { + // Numbers SCCs in topological order when acyclic. + if (scc_) { + for (size_t s = 0; s < scc_->size(); ++s) { + (*scc_)[s] = nscc_ - 1 - (*scc_)[s]; + } + } + if (coaccess_internal_) delete coaccess_; + dfnumber_.reset(); + lowlink_.reset(); + onstack_.reset(); + scc_stack_.reset(); + } + + private: + std::vector *scc_; // State's scc number. + std::vector *access_; // State's accessibility. + std::vector *coaccess_; // State's coaccessibility. + uint64 *props_; + const Fst *fst_; + StateId start_; + StateId nstates_; // State count. + StateId nscc_; // SCC count. + bool coaccess_internal_; + std::unique_ptr> dfnumber_; // State discovery times. + std::unique_ptr> + lowlink_; // lowlink[state] == dfnumber[state] => SCC root + std::unique_ptr> onstack_; // Is a state on the SCC stack? + std::unique_ptr> + scc_stack_; // SCC stack, with random access. +}; + +template +inline void SccVisitor::InitVisit(const Fst &fst) { + if (scc_) scc_->clear(); + if (access_) access_->clear(); + if (coaccess_) { + coaccess_->clear(); + coaccess_internal_ = false; + } else { + coaccess_ = new std::vector; + coaccess_internal_ = true; + } + *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible; + *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible); + fst_ = &fst; + start_ = fst.Start(); + nstates_ = 0; + nscc_ = 0; + dfnumber_.reset(new std::vector()); + lowlink_.reset(new std::vector()); + onstack_.reset(new std::vector()); + scc_stack_.reset(new std::vector()); +} + +template +inline bool SccVisitor::InitState(StateId s, StateId root) { + scc_stack_->push_back(s); + if (static_cast(dfnumber_->size()) <= s) { + if (scc_) scc_->resize(s + 1, -1); + if (access_) access_->resize(s + 1, false); + coaccess_->resize(s + 1, false); + dfnumber_->resize(s + 1, -1); + lowlink_->resize(s + 1, -1); + onstack_->resize(s + 1, false); + } + (*dfnumber_)[s] = nstates_; + (*lowlink_)[s] = nstates_; + (*onstack_)[s] = true; + if (root == start_) { + if (access_) (*access_)[s] = true; + } else { + if (access_) (*access_)[s] = false; + *props_ |= kNotAccessible; + *props_ &= ~kAccessible; + } + ++nstates_; + return true; +} + +template +inline void SccVisitor::FinishState(StateId s, StateId p, const Arc *) { + if (fst_->Final(s) != Weight::Zero()) (*coaccess_)[s] = true; + if ((*dfnumber_)[s] == (*lowlink_)[s]) { // Root of new SCC. + bool scc_coaccess = false; + auto i = scc_stack_->size(); + StateId t; + do { + t = (*scc_stack_)[--i]; + if ((*coaccess_)[t]) scc_coaccess = true; + } while (s != t); + do { + t = scc_stack_->back(); + if (scc_) (*scc_)[t] = nscc_; + if (scc_coaccess) (*coaccess_)[t] = true; + (*onstack_)[t] = false; + scc_stack_->pop_back(); + } while (s != t); + if (!scc_coaccess) { + *props_ |= kNotCoAccessible; + *props_ &= ~kCoAccessible; + } + ++nscc_; + } + if (p != kNoStateId) { + if ((*coaccess_)[s]) (*coaccess_)[p] = true; + if ((*lowlink_)[s] < (*lowlink_)[p]) (*lowlink_)[p] = (*lowlink_)[s]; + } +} + +// Trims an FST, removing states and arcs that are not on successful paths. +// This version modifies its input. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(V + E) +// +// where V = # of states and E = # of arcs. +template +void Connect(MutableFst *fst) { + using StateId = typename Arc::StateId; + std::vector access; + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + std::vector dstates; + dstates.reserve(access.size()); + for (StateId s = 0; s < access.size(); ++s) { + if (!access[s] || !coaccess[s]) dstates.push_back(s); + } + fst->DeleteStates(dstates); + fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible); +} + +// Returns an acyclic FST where each SCC in the input FST has been condensed to +// a single state with transitions between SCCs retained and within SCCs +// dropped. Also populates 'scc' with a mapping from input to output states. +template +void Condense(const Fst &ifst, MutableFst *ofst, + std::vector *scc) { + using StateId = typename Arc::StateId; + ofst->DeleteStates(); + uint64 props = 0; + SccVisitor scc_visitor(scc, nullptr, nullptr, &props); + DfsVisit(ifst, &scc_visitor); + const auto iter = std::max_element(scc->cbegin(), scc->cend()); + if (iter == scc->cend()) return; + const StateId num_condensed_states = 1 + *iter; + ofst->ReserveStates(num_condensed_states); + for (StateId c = 0; c < num_condensed_states; ++c) { + ofst->AddState(); + } + for (StateId s = 0; s < scc->size(); ++s) { + const auto c = (*scc)[s]; + if (s == ifst.Start()) ofst->SetStart(c); + const auto weight = ifst.Final(s); + if (weight != Arc::Weight::Zero()) + ofst->SetFinal(c, Plus(ofst->Final(c), weight)); + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + const auto nextc = (*scc)[arc.nextstate]; + if (nextc != c) { + Arc condensed_arc = arc; + condensed_arc.nextstate = nextc; + ofst->AddArc(c, std::move(condensed_arc)); + } + } + } + ofst->SetProperties(kAcyclic | kInitialAcyclic, kAcyclic | kInitialAcyclic); +} + +} // namespace fst + +#endif // FST_CONNECT_H_ diff --git a/projects/llm_framework/include/fst/const-fst.h b/projects/llm_framework/include/fst/const-fst.h new file mode 100644 index 00000000..09c81c7b --- /dev/null +++ b/projects/llm_framework/include/fst/const-fst.h @@ -0,0 +1,485 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Simple concrete immutable FST whose states and arcs are each stored in +// single arrays. + +#ifndef FST_CONST_FST_H_ +#define FST_CONST_FST_H_ + +#include +#include +#include + +// Google-only... +// ...Google-only +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +template +class ConstFst; + +template +void Cast(const F &, G *); + +namespace internal { + +// States and arcs each implemented by single arrays, templated on the +// Arc definition. Unsigned is used to represent indices into the arc array. +template +class ConstFstImpl : public FstImpl { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + + ConstFstImpl() + : states_(nullptr), + arcs_(nullptr), + narcs_(0), + nstates_(0), + start_(kNoStateId) { + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + SetType(type); + SetProperties(kNullProperties | kStaticProperties); + } + + explicit ConstFstImpl(const Fst &fst); + + StateId Start() const { return start_; } + + Weight Final(StateId s) const { return states_[s].weight; } + + StateId NumStates() const { return nstates_; } + + size_t NumArcs(StateId s) const { return states_[s].narcs; } + + size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; } + + size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; } + + static ConstFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + const Arc *Arcs(StateId s) const { return arcs_ + states_[s].pos; } + + // Provide information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = nstates_; + } + + // Provide information needed for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + data->base = nullptr; + data->arcs = arcs_ + states_[s].pos; + data->narcs = states_[s].narcs; + data->ref_count = nullptr; + } + + private: + // Used to find narcs_ and nstates_ in Write. + friend class ConstFst; + + // States implemented by array *states_ below, arcs by (single) *arcs_. + struct ConstState { + Weight weight; // Final weight. + Unsigned pos; // Start of state's arcs in *arcs_. + Unsigned narcs; // Number of arcs (per state). + Unsigned niepsilons; // Number of input epsilons. + Unsigned noepsilons; // Number of output epsilons. + + ConstState() : weight(Weight::Zero()) {} + }; + + // Properties always true of this FST class. + static constexpr uint64 kStaticProperties = kExpanded; + // Current unaligned file format version. The unaligned version was added and + // made the default since the aligned version does not work on pipes. + static constexpr int kFileVersion = 2; + // Current aligned file format version. + static constexpr int kAlignedFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + std::unique_ptr states_region_; // Mapped file for states. + std::unique_ptr arcs_region_; // Mapped file for arcs. + ConstState *states_; // States representation. + Arc *arcs_; // Arcs representation. + size_t narcs_; // Number of arcs. + StateId nstates_; // Number of states. + StateId start_; // Initial state. + + ConstFstImpl(const ConstFstImpl &) = delete; + ConstFstImpl &operator=(const ConstFstImpl &) = delete; +}; + +template +constexpr uint64 ConstFstImpl::kStaticProperties; + +template +constexpr int ConstFstImpl::kFileVersion; + +template +constexpr int ConstFstImpl::kAlignedFileVersion; + +template +constexpr int ConstFstImpl::kMinFileVersion; + +template +ConstFstImpl::ConstFstImpl(const Fst &fst) + : narcs_(0), nstates_(0) { + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + SetType(type); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + start_ = fst.Start(); + // Counts states and arcs. + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + narcs_ += fst.NumArcs(siter.Value()); + } + states_region_.reset(MappedFile::Allocate(nstates_ * sizeof(*states_))); + arcs_region_.reset(MappedFile::Allocate(narcs_ * sizeof(*arcs_))); + states_ = reinterpret_cast(states_region_->mutable_data()); + arcs_ = reinterpret_cast(arcs_region_->mutable_data()); + size_t pos = 0; + for (StateId s = 0; s < nstates_; ++s) { + states_[s].weight = fst.Final(s); + states_[s].pos = pos; + states_[s].narcs = 0; + states_[s].niepsilons = 0; + states_[s].noepsilons = 0; + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + ++states_[s].narcs; + if (arc.ilabel == 0) ++states_[s].niepsilons; + if (arc.olabel == 0) ++states_[s].noepsilons; + arcs_[pos] = arc; + ++pos; + } + } + const auto props = + fst.Properties(kMutable, false) + ? fst.Properties(kCopyProperties, true) + : CheckProperties( + fst, kCopyProperties & ~kWeightedCycles & ~kUnweightedCycles, + kCopyProperties); + SetProperties(props | kStaticProperties); +} + +template +ConstFstImpl *ConstFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { + using ConstState = typename ConstFstImpl::ConstState; + std::unique_ptr> impl( + new ConstFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->start_ = hdr.Start(); + impl->nstates_ = hdr.NumStates(); + impl->narcs_ = hdr.NumArcs(); + // Ensures compatibility. + if (hdr.Version() == kAlignedFileVersion) { + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + } + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + return nullptr; + } + size_t b = impl->nstates_ * sizeof(ConstState); + impl->states_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !impl->states_region_) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->states_ = + reinterpret_cast(impl->states_region_->mutable_data()); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + return nullptr; + } + b = impl->narcs_ * sizeof(Arc); + impl->arcs_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !impl->arcs_region_) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->arcs_ = reinterpret_cast(impl->arcs_region_->mutable_data()); + return impl.release(); +} + +} // namespace internal + +// Simple concrete immutable FST. This class attaches interface to +// implementation and handles reference counting, delegating most methods to +// ImplToExpandedFst. The unsigned type U is used to represent indices into the +// arc array (default declared in fst-decl.h). +template +class ConstFst : public ImplToExpandedFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::ConstFstImpl; + using ConstState = typename Impl::ConstState; + + friend class StateIterator>; + friend class ArcIterator>; + + template + void friend Cast(const F &, G *); + + ConstFst() : ImplToExpandedFst(std::make_shared()) {} + + explicit ConstFst(const Fst &fst) + : ImplToExpandedFst(std::make_shared(fst)) {} + + ConstFst(const ConstFst &fst, bool safe = false) + : ImplToExpandedFst(fst) {} + + // Gets a copy of this ConstFst. See Fst<>::Copy() for further doc. + ConstFst *Copy(bool safe = false) const override { + return new ConstFst(*this, safe); + } + + // Reads a ConstFst from an input stream, returning nullptr on error. + static ConstFst *Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new ConstFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a ConstFst from a file; return nullptr on error; empty filename reads + // from standard input. + static ConstFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? new ConstFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return WriteFst(*this, strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + private: + explicit ConstFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + using ImplToFst>::GetImpl; + + // Uses overloading to extract the type of the argument. + static const Impl *GetImplIfConstFst(const ConstFst &const_fst) { + return const_fst.GetImpl(); + } + + // NB: this does not give privileged treatment to subtypes of ConstFst. + template + static Impl *GetImplIfConstFst(const FST &fst) { + return nullptr; + } + + ConstFst &operator=(const ConstFst &) = delete; +}; + +// Writes FST in Const format, potentially with a pass over the machine before +// writing to compute number of states and arcs. +template +template +bool ConstFst::WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts) { + const auto file_version = + opts.align ? internal::ConstFstImpl::kAlignedFileVersion + : internal::ConstFstImpl::kFileVersion; + size_t num_arcs = 0; // To silence -Wsometimes-uninitialized warnings. + size_t num_states = 0; // Ditto. + std::streamoff start_offset = 0; + bool update_header = true; + if (const auto *impl = GetImplIfConstFst(fst)) { + num_arcs = impl->narcs_; + num_states = impl->nstates_; + update_header = false; + } else if (opts.stream_write || (start_offset = strm.tellp()) == -1) { + // precompute values needed for header when we cannot seek to rewrite it. + num_arcs = 0; + num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + num_arcs += fst.NumArcs(siter.Value()); + ++num_states; + } + update_header = false; + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + const auto properties = + fst.Properties(kCopyProperties, true) | + internal::ConstFstImpl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, type, + properties, &hdr); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after header"; + return false; + } + size_t pos = 0; + size_t states = 0; + typename ConstFst::ConstState state; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + state.weight = fst.Final(s); + state.pos = pos; + state.narcs = fst.NumArcs(s); + state.niepsilons = fst.NumInputEpsilons(s); + state.noepsilons = fst.NumOutputEpsilons(s); + strm.write(reinterpret_cast(&state), sizeof(state)); + pos += state.narcs; + ++states; + } + hdr.SetNumStates(states); + hdr.SetNumArcs(pos); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + for (ArcIterator aiter(fst, siter.Value()); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); +// Google-only... +#ifdef MEMORY_SANITIZER + // arc may contain padding which has unspecified contents. Tell MSAN to + // not complain about it when writing it to a file. + ANNOTATE_MEMORY_IS_INITIALIZED(reinterpret_cast(&arc), + sizeof(arc)); +#endif + // ...Google-only + strm.write(reinterpret_cast(&arc), sizeof(arc)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "ConstFst::WriteFst: write failed: " << opts.source; + return false; + } + if (update_header) { + return internal::FstImpl::UpdateFstHeader( + fst, strm, opts, file_version, type, properties, &hdr, start_offset); + } else { + if (hdr.NumStates() != num_states) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + if (hdr.NumArcs() != num_arcs) { + LOG(ERROR) << "Inconsistent number of arcs observed during write"; + return false; + } + } + return true; +} + +// Specialization for ConstFst; see generic version in fst.h for sample usage +// (but use the ConstFst type instead). This version should inline. +template +class StateIterator> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const ConstFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + const StateId nstates_; + StateId s_; +}; + +// Specialization for ConstFst; see generic version in fst.h for sample usage +// (but use the ConstFst type instead). This version should inline. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ConstFst &fst, StateId s) + : arcs_(fst.GetImpl()->Arcs(s)), + narcs_(fst.GetImpl()->NumArcs(s)), + i_(0) {} + + bool Done() const { return i_ >= narcs_; } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32, uint32) {} + + private: + const Arc *arcs_; + size_t narcs_; + size_t i_; +}; + +// A useful alias when using StdArc. +using StdConstFst = ConstFst; + +} // namespace fst + +#endif // FST_CONST_FST_H_ diff --git a/projects/llm_framework/include/fst/determinize.h b/projects/llm_framework/include/fst/determinize.h new file mode 100644 index 00000000..736a140e --- /dev/null +++ b/projects/llm_framework/include/fst/determinize.h @@ -0,0 +1,1093 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to determinize an FST. + +#ifndef FST_DETERMINIZE_H_ +#define FST_DETERMINIZE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Common divisors are used in determinization to compute transition weights. +// In the simplest case, it is the same as semiring Plus, but other choices +// permit more efficient determinization when the output contains strings. + +// The default common divisor uses the semiring Plus. +template +struct DefaultCommonDivisor { + public: + using Weight = W; + + Weight operator()(const Weight &w1, const Weight &w2) const { + return Plus(w1, w2); + } +}; + +// The label common divisor for a (left) string semiring selects a single +// letter common prefix or the empty string. This is used in the +// determinization of output strings so that at most a single letter will +// appear in the output of a transtion. +template +struct LabelCommonDivisor { + public: + using Weight = StringWeight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + typename Weight::Iterator iter1(w1); + typename Weight::Iterator iter2(w2); + if (!(StringWeight::Properties() & kLeftSemiring)) { + FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring"; + return Weight::NoWeight(); + } else if (w1.Size() == 0 || w2.Size() == 0) { + return Weight::One(); + } else if (w1 == Weight::Zero()) { + return Weight(iter2.Value()); + } else if (w2 == Weight::Zero()) { + return Weight(iter1.Value()); + } else if (iter1.Value() == iter2.Value()) { + return Weight(iter1.Value()); + } else { + return Weight::One(); + } + } +}; + +// The gallic common divisor uses the label common divisor on the string +// component and the common divisor on the weight component, which defaults to +// the default common divisor. +template > +class GallicCommonDivisor { + public: + using Weight = GallicWeight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + return Weight(label_common_divisor_(w1.Value1(), w2.Value1()), + weight_common_divisor_(w1.Value2(), w2.Value2())); + } + + private: + LabelCommonDivisor label_common_divisor_; + CommonDivisor weight_common_divisor_; +}; + +// Specialization for general GALLIC weight. +template +class GallicCommonDivisor { + public: + using Weight = GallicWeight; + using GRWeight = GallicWeight; + using Iterator = + UnionWeightIterator>; + + Weight operator()(const Weight &w1, const Weight &w2) const { + auto weight = GRWeight::Zero(); + for (Iterator iter(w1); !iter.Done(); iter.Next()) { + weight = common_divisor_(weight, iter.Value()); + } + for (Iterator iter(w2); !iter.Done(); iter.Next()) { + weight = common_divisor_(weight, iter.Value()); + } + return weight == GRWeight::Zero() ? Weight::Zero() : Weight(weight); + } + + private: + GallicCommonDivisor common_divisor_; +}; + +namespace internal { + +// Represents an element in a subset +template +struct DeterminizeElement { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DeterminizeElement(StateId s, Weight weight) + : state_id(s), weight(std::move(weight)) {} + + inline bool operator==(const DeterminizeElement &element) const { + return state_id == element.state_id && weight == element.weight; + } + + inline bool operator!=(const DeterminizeElement &element) const { + return !(*this == element); + } + + inline bool operator<(const DeterminizeElement &element) const { + return state_id < element.state_id; + } + + StateId state_id; // Input state ID. + Weight weight; // Residual weight. +}; + +// Represents a weighted subset and determinization filter state +template +struct DeterminizeStateTuple { + using Arc = A; + using Element = DeterminizeElement; + using Subset = std::forward_list; + + DeterminizeStateTuple() : filter_state(FilterState::NoState()) {} + + inline bool operator==( + const DeterminizeStateTuple &tuple) const { + return (tuple.filter_state == filter_state) && (tuple.subset == subset); + } + + inline bool operator!=( + const DeterminizeStateTuple &tuple) const { + return (tuple.filter_state != filter_state) || (tuple.subset != subset); + } + + Subset subset; + FilterState filter_state; +}; + +// Proto-transition for determinization. +template +struct DeterminizeArc { + using Arc = typename StateTuple::Arc; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + DeterminizeArc() + : label(kNoLabel), weight(Weight::Zero()), dest_tuple(nullptr) {} + + explicit DeterminizeArc(const Arc &arc) + : label(arc.ilabel), weight(Weight::Zero()), dest_tuple(new StateTuple) {} + + Label label; // Arc label. + Weight weight; // Arc weight. + StateTuple *dest_tuple; // Destination subset and filter state. +}; + +} // namespace internal + +// Determinization filters are used to compute destination state tuples based +// on the source tuple, transition, and destination element or on similar +// super-final transition information. The filter operates on a map between a +// label and the corresponding destination state tuples. It must define the map +// type LabelMap. The default filter is used for weighted determinization. +// A determinize filter for implementing weighted determinization. +template +class DefaultDeterminizeFilter { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = CharFilterState; + using Element = internal::DeterminizeElement; + using StateTuple = internal::DeterminizeStateTuple; + using LabelMap = std::map>; + + // This is needed e.g. to go into the gallic domain for transducers. + template + struct rebind { + using Other = DefaultDeterminizeFilter; + }; + + explicit DefaultDeterminizeFilter(const Fst &fst) : fst_(fst.Copy()) {} + + // This is needed (e.g.) to go into the gallic domain for transducers. + // Ownership of the templated filter argument is given to this class. + template + DefaultDeterminizeFilter(const Fst &fst, Filter *filter) + : fst_(fst.Copy()) { + delete filter; + } + + // Copy constructor; the FST can be passed if it has been deep-copied. + DefaultDeterminizeFilter(const DefaultDeterminizeFilter &filter, + const Fst *fst = nullptr) + : fst_(fst ? fst->Copy() : filter.fst_->Copy()) {} + + FilterState Start() const { return FilterState(0); } + + // Does no work. + void SetState(StateId s, const StateTuple &tuple) {} + + // Filters transition, possibly modifying label map. Returns true if arc is + // added to the label map. + bool FilterArc(const Arc &arc, const Element &src_element, + Element &&dest_element, LabelMap *label_map) const { + // Adds element to unique state tuple for arc label. + auto &det_arc = (*label_map)[arc.ilabel]; + if (det_arc.label == kNoLabel) { + det_arc = internal::DeterminizeArc(arc); + det_arc.dest_tuple->filter_state = FilterState(0); + } + det_arc.dest_tuple->subset.push_front(std::move(dest_element)); + return true; + } + + // Filters super-final transition, returning new final weight. + Weight FilterFinal(Weight weight, const Element &element) { return weight; } + + static uint64 Properties(uint64 props) { return props; } + + private: + std::unique_ptr> fst_; +}; + +// Determinization state table interface: +// +// template +// class DeterminizeStateTable { +// public: +// using StateId = typename Arc::StateId; +// using StateTuple = internal::DeterminizeStateTuple; +// +// // Required sub-class. This is needed (e.g.) to go into the gallic domain. +// template +// struct rebind { +// using Other = DeterminizeStateTable; +// } +// +// // Required constuctor. +// DeterminizeStateTable(); +// +// // Required copy constructor that does not copy state. +// DeterminizeStateTable(const DeterminizeStateTable +// &table); +// +// // Looks up state ID by state tuple; if it doesn't exist, then adds it. +// // FindState takes ownership of the state tuple argument so that it +// // doesn't have to copy it if it creates a new state. +// StateId FindState(StateTuple *tuple); +// +// // Looks up state tuple by ID. +// const StateTuple *Tuple(StateId id) const; +// }; + +// The default determinization state table based on the compact hash bi-table. +template +class DefaultDeterminizeStateTable { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StateTuple = internal::DeterminizeStateTuple; + using Element = typename StateTuple::Element; + using Subset = typename StateTuple::Subset; + + template + struct rebind { + using Other = DefaultDeterminizeStateTable; + }; + + explicit DefaultDeterminizeStateTable(size_t table_size = 0) + : table_size_(table_size), tuples_(table_size_) {} + + DefaultDeterminizeStateTable( + const DefaultDeterminizeStateTable &table) + : table_size_(table.table_size_), tuples_(table_size_) {} + + ~DefaultDeterminizeStateTable() { + for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s); + } + + // Finds the state corresponding to a state tuple. Only creates a new state if + // the tuple is not found. FindState takes ownership of the tuple argument so + // that it doesn't have to copy it if it creates a new state. + StateId FindState(StateTuple *tuple) { + const StateId ns = tuples_.Size(); + const auto s = tuples_.FindId(tuple); + if (s != ns) delete tuple; // Tuple found. + return s; + } + + const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); } + + private: + // Comparison object for StateTuples. + class StateTupleEqual { + public: + bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const { + return *tuple1 == *tuple2; + } + }; + + // Hash function for StateTuples. + class StateTupleKey { + public: + size_t operator()(const StateTuple *tuple) const { + size_t h = tuple->filter_state.Hash(); + for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) { + const size_t h1 = it->state_id; + static constexpr auto lshift = 5; + static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5; + h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ it->weight.Hash(); + } + return h; + } + }; + + size_t table_size_; + CompactHashBiTable + tuples_; + + DefaultDeterminizeStateTable &operator=( + const DefaultDeterminizeStateTable &) = delete; +}; + +// Determinization type. +enum DeterminizeType { + // Input transducer is known to be functional (or error). + DETERMINIZE_FUNCTIONAL, // Input transducer is functional (error if not). + // Input transducer is not known to be functional. + DETERMINIZE_NONFUNCTIONAL, + // Input transducer is not known to be functional but only keep the min of + // of ambiguous outputs. + DETERMINIZE_DISAMBIGUATE +}; + +// Options for finite-state transducer determinization templated on the arc +// type, common divisor, the determinization filter and the state table. +// DeterminizeFst takes ownership of the determinization filter and state table, +// if provided. +template , + class Filter = DefaultDeterminizeFilter, + class StateTable = + DefaultDeterminizeStateTable> +struct DeterminizeFstOptions : public CacheOptions { + using Label = typename Arc::Label; + + float delta; // Quantization delta for subset weights. + Label subsequential_label; // Label used for residual final output + // when producing subsequential transducers. + DeterminizeType type; // Determinization type. + bool increment_subsequential_label; // When creating several subsequential + // arcs at a given state, make their + // label distinct by incrementing. + Filter *filter; // Determinization filter; + // DeterminizeFst takes ownership. + StateTable *state_table; // Determinization state table; + // DeterminizeFst takes ownership. + + explicit DeterminizeFstOptions(const CacheOptions &opts, float delta = kDelta, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : CacheOptions(opts), + delta(delta), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label), + filter(filter), + state_table(state_table) {} + + explicit DeterminizeFstOptions(float delta = kDelta, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : delta(delta), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label), + filter(filter), + state_table(state_table) {} +}; + +namespace internal { + +// Implementation of delayed DeterminizeFst. This base class is +// common to the variants that implement acceptor and transducer +// determinization. +template +class DeterminizeFstImplBase : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + template + DeterminizeFstImplBase( + const Fst &fst, + const DeterminizeFstOptions &opts) + : CacheImpl(opts), fst_(fst.Copy()) { + SetType("determinize"); + const auto iprops = fst.Properties(kFstProperties, false); + const auto dprops = + DeterminizeProperties(iprops, opts.subsequential_label != 0, + opts.type == DETERMINIZE_NONFUNCTIONAL + ? opts.increment_subsequential_label + : true); + SetProperties(Filter::Properties(dprops), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + DeterminizeFstImplBase(const DeterminizeFstImplBase &impl) + : CacheImpl(impl), fst_(impl.fst_->Copy(true)) { + SetType("determinize"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + virtual DeterminizeFstImplBase *Copy() const = 0; + + StateId Start() { + if (!HasStart()) { + const auto start = ComputeStart(); + if (start != kNoStateId) SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) SetFinal(s, ComputeFinal(s)); + return CacheImpl::Final(s); + } + + virtual void Expand(StateId s) = 0; + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + virtual StateId ComputeStart() = 0; + + virtual Weight ComputeFinal(StateId s) = 0; + + const Fst &GetFst() const { return *fst_; } + + private: + std::unique_ptr> fst_; // Input FST. +}; + +// Implementation of delayed determinization for weighted acceptors. +template +class DeterminizeFsaImpl : public DeterminizeFstImplBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = typename Filter::FilterState; + using StateTuple = internal::DeterminizeStateTuple; + using Element = typename StateTuple::Element; + using Subset = typename StateTuple::Subset; + using LabelMap = typename Filter::LabelMap; + + using FstImpl::SetProperties; + using DeterminizeFstImplBase::GetFst; + using DeterminizeFstImplBase::SetArcs; + + DeterminizeFsaImpl( + const Fst &fst, const std::vector *in_dist, + std::vector *out_dist, + const DeterminizeFstOptions &opts) + : DeterminizeFstImplBase(fst, opts), + delta_(opts.delta), + in_dist_(in_dist), + out_dist_(out_dist), + filter_(opts.filter ? opts.filter : new Filter(fst)), + state_table_(opts.state_table ? opts.state_table : new StateTable()) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst: Argument not an acceptor"; + SetProperties(kError, kError); + } + if (!(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "DeterminizeFst: Weight must be left distributive: " + << Weight::Type(); + SetProperties(kError, kError); + } + if (out_dist_) out_dist_->clear(); + } + + DeterminizeFsaImpl( + const DeterminizeFsaImpl &impl) + : DeterminizeFstImplBase(impl), + delta_(impl.delta_), + in_dist_(nullptr), + out_dist_(nullptr), + filter_(new Filter(*impl.filter_, &GetFst())), + state_table_(new StateTable(*impl.state_table_)) { + if (impl.out_dist_) { + FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector"; + SetProperties(kError, kError); + } + } + + DeterminizeFsaImpl *Copy() + const override { + return new DeterminizeFsaImpl( + *this); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (GetFst().Properties(kError, false))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + StateId ComputeStart() override { + const auto s = GetFst().Start(); + if (s == kNoStateId) return kNoStateId; + auto *tuple = new StateTuple; + tuple->subset.emplace_front(s, Weight::One()); + tuple->filter_state = filter_->Start(); + return FindState(tuple); + } + + Weight ComputeFinal(StateId s) override { + const auto *tuple = state_table_->Tuple(s); + filter_->SetState(s, *tuple); + auto final_weight = Weight::Zero(); + for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) { + const auto &element = *it; + final_weight = + Plus(final_weight, + Times(element.weight, GetFst().Final(element.state_id))); + final_weight = filter_->FilterFinal(final_weight, element); + if (!final_weight.Member()) SetProperties(kError, kError); + } + return final_weight; + } + + StateId FindState(StateTuple *tuple) { + const auto s = state_table_->FindState(tuple); + if (in_dist_ && out_dist_->size() <= s) { + out_dist_->push_back(ComputeDistance(tuple->subset)); + } + return s; + } + + // Computes distance from a state to the final states in the DFA given the + // distances in the NFA. + Weight ComputeDistance(const Subset &subset) { + auto outd = Weight::Zero(); + for (auto it = subset.begin(); it != subset.end(); ++it) { + const auto &element = *it; + const auto ind = + (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id] + : Weight::Zero()); + outd = Plus(outd, Times(element.weight, ind)); + } + return outd; + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) override { + LabelMap label_map; + GetLabelMap(s, &label_map); + for (auto it = label_map.begin(); it != label_map.end(); ++it) { + AddArc(s, std::move(it->second)); + } + SetArcs(s); + } + + private: + using DetArc = internal::DeterminizeArc; + + // Constructs proto-determinization transition, including destination subset, + // per label. + void GetLabelMap(StateId s, LabelMap *label_map) { + const auto *src_tuple = state_table_->Tuple(s); + filter_->SetState(s, *src_tuple); + for (auto it = src_tuple->subset.begin(); it != src_tuple->subset.end(); + ++it) { + const auto &src_element = *it; + for (ArcIterator> aiter(GetFst(), src_element.state_id); + !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + Element dest_element(arc.nextstate, + Times(src_element.weight, arc.weight)); + filter_->FilterArc(arc, src_element, std::move(dest_element), + label_map); + } + } + for (auto it = label_map->begin(); it != label_map->end(); ++it) { + NormArc(&it->second); + } + } + + // Sorts subsets and removes duplicate elements, normalizing transition and + // subset weights. + void NormArc(DetArc *det_arc) { + auto *dest_tuple = det_arc->dest_tuple; + dest_tuple->subset.sort(); + auto piter = dest_tuple->subset.begin(); + for (auto diter = dest_tuple->subset.begin(); + diter != dest_tuple->subset.end();) { + auto &dest_element = *diter; + auto &prev_element = *piter; + // Computes arc weight. + det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight); + if (piter != diter && dest_element.state_id == prev_element.state_id) { + // Found duplicate state: sums state weight and deletes duplicate. + prev_element.weight = Plus(prev_element.weight, dest_element.weight); + if (!prev_element.weight.Member()) SetProperties(kError, kError); + ++diter; + dest_tuple->subset.erase_after(piter); + } else { + piter = diter; + ++diter; + } + } + // Divides out label weight from destination subset elements, quantizing to + // ensure comparisons are effective. + for (auto diter = dest_tuple->subset.begin(); + diter != dest_tuple->subset.end(); ++diter) { + auto &dest_element = *diter; + dest_element.weight = + Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT); + dest_element.weight = dest_element.weight.Quantize(delta_); + } + } + + // Adds an arc from state S to the destination state associated with state + // tuple in det_arc as created by GetLabelMap. + void AddArc(StateId s, DetArc &&det_arc) { + CacheImpl::EmplaceArc( + s, det_arc.label, det_arc.label, std::move(det_arc.weight), + FindState(det_arc.dest_tuple)); + } + + float delta_; // Quantization delta for weights. + const std::vector *in_dist_; // Distance to final NFA states. + std::vector *out_dist_; // Distance to final DFA states. + + // FIXME(kbg): Ought to be static const? + CommonDivisor common_divisor_; + std::unique_ptr filter_; + std::unique_ptr state_table_; +}; + +// Implementation of delayed determinization for transducers. Transducer +// determinization is implemented by mapping the input to the Gallic semiring as +// an acceptor whose weights contain the output strings and using acceptor +// determinization above to determinize that acceptor. +template +class DeterminizeFstImpl : public DeterminizeFstImplBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ToMapper = ToGallicMapper; + using ToArc = typename ToMapper::ToArc; + using ToFst = ArcMapFst; + using FromMapper = FromGallicMapper; + using FromFst = ArcMapFst; + + using ToCommonDivisor = GallicCommonDivisor; + using ToFilter = typename Filter::template rebind::Other; + using ToFilterState = typename ToFilter::FilterState; + using ToStateTable = + typename StateTable::template rebind::Other; + using FactorIterator = GallicFactor; + + using FstImpl::SetProperties; + using DeterminizeFstImplBase::GetFst; + using CacheBaseImpl>::GetCacheGc; + using CacheBaseImpl>::GetCacheLimit; + + DeterminizeFstImpl( + const Fst &fst, + const DeterminizeFstOptions &opts) + : DeterminizeFstImplBase(fst, opts), + delta_(opts.delta), + subsequential_label_(opts.subsequential_label), + increment_subsequential_label_(opts.increment_subsequential_label) { + if (opts.state_table) { + FSTERROR() << "DeterminizeFst: " + << "A state table can not be passed with transducer input"; + SetProperties(kError, kError); + return; + } + Init(GetFst(), opts.filter); + } + + DeterminizeFstImpl( + const DeterminizeFstImpl &impl) + : DeterminizeFstImplBase(impl), + delta_(impl.delta_), + subsequential_label_(impl.subsequential_label_), + increment_subsequential_label_(impl.increment_subsequential_label_) { + Init(GetFst(), nullptr); + } + + DeterminizeFstImpl *Copy() + const override { + return new DeterminizeFstImpl( + *this); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (GetFst().Properties(kError, false) || + from_fst_->Properties(kError, false))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + StateId ComputeStart() override { return from_fst_->Start(); } + + Weight ComputeFinal(StateId s) override { return from_fst_->Final(s); } + + void Expand(StateId s) override { + for (ArcIterator aiter(*from_fst_, s); !aiter.Done(); + aiter.Next()) { + CacheImpl::PushArc(s, aiter.Value()); + } + CacheImpl::SetArcs(s); + } + + private: + // Initialization of transducer determinization implementation, which is + // defined after DeterminizeFst since it calls it. + void Init(const Fst &fst, Filter *filter); + + float delta_; + Label subsequential_label_; + bool increment_subsequential_label_; + std::unique_ptr from_fst_; +}; + +} // namespace internal + +// Determinizes a weighted transducer. This version is a delayed +// FST. The result will be an equivalent FST that has the property +// that no state has two transitions with the same input label. +// For this algorithm, epsilon transitions are treated as regular +// symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) left +// divisible (valid for TropicalWeight and LogWeight for instance) and be +// zero-sum-free if for all a, b: (Plus(a, b) == 0) => a = b = 0. +// +// Complexity: +// +// Determinizable: exponential (polynomial in the size of the output). +// Non-determinizable: does not terminate. +// +// The determinizable automata include all unweighted and all acyclic input. +// +// For more information, see: +// +// Mohri, M. 1997. Finite-state transducers in language and speech processing. +// Computational Linguistics 23(2): 269-311. +// +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToFst. +template +class DeterminizeFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::DeterminizeFstImplBase; + + friend class ArcIterator>; + friend class StateIterator>; + + template + friend class DeterminizeFstImpl; + + explicit DeterminizeFst(const Fst &fst) + : ImplToFst(CreateImpl(fst)) {} + + template + DeterminizeFst( + const Fst &fst, + const DeterminizeFstOptions + &opts = + DeterminizeFstOptions()) + : ImplToFst(CreateImpl(fst, opts)) {} + + // This acceptor-only version additionally computes the distance to final + // states in the output if provided with those distances for the input; this + // is useful for e.g., computing the k-shortest unique paths. + template + DeterminizeFst( + const Fst &fst, const std::vector *in_dist, + std::vector *out_dist, + const DeterminizeFstOptions + &opts = + DeterminizeFstOptions()) + : ImplToFst( + std::make_shared>( + fst, in_dist, out_dist, opts)) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst: " + << "Distance to final states computed for acceptors only"; + GetMutableImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + DeterminizeFst(const DeterminizeFst &fst, bool safe = false) + : ImplToFst(safe ? std::shared_ptr(fst.GetImpl()->Copy()) + : fst.GetSharedImpl()) {} + + // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc. + DeterminizeFst *Copy(bool safe = false) const override { + return new DeterminizeFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + static std::shared_ptr CreateImpl(const Fst &fst) { + using D = DefaultCommonDivisor; + using F = DefaultDeterminizeFilter; + using T = DefaultDeterminizeStateTable; + const DeterminizeFstOptions opts; + return CreateImpl(fst, opts); + } + + template + static std::shared_ptr CreateImpl( + const Fst &fst, + const DeterminizeFstOptions + &opts) { + if (fst.Properties(kAcceptor, true)) { + // Calls implementation for acceptors. + return std::make_shared< + internal::DeterminizeFsaImpl>( + fst, nullptr, nullptr, opts); + } else if (opts.type == DETERMINIZE_DISAMBIGUATE) { + auto rv = std::make_shared>(fst, opts); + if (!(Weight::Properties() & kPath)) { + FSTERROR() << "DeterminizeFst: Weight needs to have the " + << "path property to disambiguate output: " + << Weight::Type(); + rv->SetProperties(kError, kError); + } + // Calls disambiguating implementation for non-functional transducers. + return rv; + } else if (opts.type == DETERMINIZE_FUNCTIONAL) { + // Calls implementation for functional transducers. + return std::make_shared>(fst, opts); + } else { // opts.type == DETERMINIZE_NONFUNCTIONAL + // Calls implementation for non functional transducers; + return std::make_shared>(fst, opts); + } + } + + DeterminizeFst &operator=(const DeterminizeFst &) = delete; +}; + +namespace internal { + +// Initialization of transducer determinization implementation, which is defined +// after DeterminizeFst since it calls it. +template +void DeterminizeFstImpl::Init(const Fst &fst, Filter *filter) { + // Mapper to an acceptor. + const ToFst to_fst(fst, ToMapper()); + auto *to_filter = filter ? new ToFilter(to_fst, filter) : nullptr; + // This recursive call terminates since it is to a (non-recursive) + // different constructor. + const CacheOptions copts(GetCacheGc(), GetCacheLimit()); + const DeterminizeFstOptions + dopts(copts, delta_, 0, DETERMINIZE_FUNCTIONAL, false, to_filter); + // Uses acceptor-only constructor to avoid template recursion. + const DeterminizeFst det_fsa(to_fst, nullptr, nullptr, dopts); + // Mapper back to transducer. + const FactorWeightOptions fopts( + CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_, + subsequential_label_, increment_subsequential_label_, + increment_subsequential_label_); + const FactorWeightFst factored_fst(det_fsa, fopts); + from_fst_.reset(new FromFst(factored_fst, FromMapper(subsequential_label_))); +} + +} // namespace internal + +// Specialization for DeterminizeFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const DeterminizeFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for DeterminizeFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const DeterminizeFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void DeterminizeFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful aliases when using StdArc. +using StdDeterminizeFst = DeterminizeFst; + +template +struct DeterminizeOptions { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + float delta; // Quantization delta for subset weights. + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + Label subsequential_label; // Label used for residual final output. + DeterminizeType type; + bool increment_subsequential_label; // When creating several subsequential + // arcs at a given state, make their + // label distinct by incrementation? + + explicit DeterminizeOptions(float delta = kDelta, + Weight weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false) + : delta(delta), + weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label) {} +}; + +// Determinizes a weighted transducer. This version writes the +// determinized Fst to an output MutableFst. The result will be an +// equivalent FST that has the property that no state has two +// transitions with the same input label. For this algorithm, epsilon +// transitions are treated as regular symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) +// left divisible (valid for TropicalWeight and LogWeight). +// +// Complexity: +// +// Determinizable: exponential (polynomial in the size of the output) +// Non-determinizable: does not terminate +// +// The determinizable automata include all unweighted and all acyclic input. +template +void Determinize( + const Fst &ifst, MutableFst *ofst, + const DeterminizeOptions &opts = DeterminizeOptions()) { + using Weight = typename Arc::Weight; + DeterminizeFstOptions nopts; + nopts.delta = opts.delta; + nopts.subsequential_label = opts.subsequential_label; + nopts.type = opts.type; + nopts.increment_subsequential_label = opts.increment_subsequential_label; + nopts.gc_limit = 0; // Caches only the last state for fastest copy. + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + if (ifst.Properties(kAcceptor, false)) { + std::vector idistance; + std::vector odistance; + ShortestDistance(ifst, &idistance, true); + DeterminizeFst dfst(ifst, &idistance, &odistance, nopts); + PruneOptions> popts( + opts.weight_threshold, opts.state_threshold, AnyArcFilter(), + &odistance); + Prune(dfst, ofst, popts); + } else { + *ofst = DeterminizeFst(ifst, nopts); + Prune(ofst, opts.weight_threshold, opts.state_threshold); + } + } else { + *ofst = DeterminizeFst(ifst, nopts); + } +} + +} // namespace fst + +#endif // FST_DETERMINIZE_H_ diff --git a/projects/llm_framework/include/fst/dfs-visit.h b/projects/llm_framework/include/fst/dfs-visit.h new file mode 100644 index 00000000..a7b18a6c --- /dev/null +++ b/projects/llm_framework/include/fst/dfs-visit.h @@ -0,0 +1,202 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Depth-first search visitation. See visit.h for more general search queue +// disciplines. + +#ifndef FST_DFS_VISIT_H_ +#define FST_DFS_VISIT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Visitor Interface: class determining actions taken during a depth-first +// search-style visit. If any of the boolean member functions return false, the +// DFS is aborted by first calling FinishState() on all currently grey states +// and then calling FinishVisit(). +// +// This is similar to the more general visitor interface in visit.h, except +// that FinishState returns additional information appropriate only for a DFS +// and some methods names here are better suited to a DFS. +// +// template +// class Visitor { +// public: +// using StateId = typename Arc::StateId; +// +// Visitor(T *return_data); +// +// // Invoked before DFS visit. +// void InitVisit(const Fst &fst); +// +// // Invoked when state discovered (2nd arg is DFS tree root). +// bool InitState(StateId s, StateId root); +// +// // Invoked when tree arc to white/undiscovered state examined. +// bool TreeArc(StateId s, const Arc &arc); +// +// // Invoked when back arc to grey/unfinished state examined. +// bool BackArc(StateId s, const Arc &arc); +// +// // Invoked when forward or cross arc to black/finished state examined. +// bool ForwardOrCrossArc(StateId s, const Arc &arc); +// +// // Invoked when state finished ('s' is tree root, 'parent' is kNoStateId, +// // and 'arc' is nullptr). +// void FinishState(StateId s, StateId parent, const Arc *arc); +// +// // Invoked after DFS visit. +// void FinishVisit(); +// }; + +namespace internal { + +// An FST state's DFS stack state. +template +struct DfsState { + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + DfsState(const FST &fst, StateId s) : state_id(s), arc_iter(fst, s) {} + + void *operator new(size_t size, MemoryPool> *pool) { + return pool->Allocate(); + } + + static void Destroy(DfsState *dfs_state, + MemoryPool> *pool) { + if (dfs_state) { + dfs_state->~DfsState(); + pool->Free(dfs_state); + } + } + + StateId state_id; // FST state. + ArcIterator arc_iter; // The corresponding arcs. +}; + +} // namespace internal + +// Performs depth-first visitation. Visitor class argument determines actions +// and contains any return data. ArcFilter determines arcs that are considered. +// If 'access_only' is true, performs visitation only to states accessible from +// the initial state. +// +// Note this is similar to Visit() in visit.h called with a LIFO queue, except +// this version has a Visitor class specialized and augmented for a DFS. +template +void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, + bool access_only = false) { + visitor->InitVisit(fst); + const auto start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + // An FST state's DFS status + static constexpr uint8 kDfsWhite = 0; // Undiscovered. + static constexpr uint8 kDfsGrey = 1; // Discovered but unfinished. + static constexpr uint8 kDfsBlack = 2; // Finished. + std::vector state_color; + std::stack *> state_stack; // DFS execution stack. + MemoryPool> state_pool; // Pool for DFSStates. + auto nstates = start + 1; // Number of known states in general case. + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // Tests if expanded case, then + nstates = CountStates(fst); // uses ExpandedFst::NumStates(). + expanded = true; + } + state_color.resize(nstates, kDfsWhite); + StateIterator siter(fst); + // Continue DFS while true. + bool dfs = true; + // Iterate over trees in DFS forest. + for (auto root = start; dfs && root < nstates;) { + state_color[root] = kDfsGrey; + state_stack.push(new (&state_pool) internal::DfsState(fst, root)); + dfs = visitor->InitState(root, root); + while (!state_stack.empty()) { + auto *dfs_state = state_stack.top(); + const auto s = dfs_state->state_id; + if (s >= static_cast(state_color.size())) { + nstates = s + 1; + state_color.resize(nstates, kDfsWhite); + } + ArcIterator &aiter = dfs_state->arc_iter; + if (!dfs || aiter.Done()) { + state_color[s] = kDfsBlack; + internal::DfsState::Destroy(dfs_state, &state_pool); + state_stack.pop(); + if (!state_stack.empty()) { + auto *parent_state = state_stack.top(); + auto &piter = parent_state->arc_iter; + visitor->FinishState(s, parent_state->state_id, &piter.Value()); + piter.Next(); + } else { + visitor->FinishState(s, kNoStateId, nullptr); + } + continue; + } + const auto &arc = aiter.Value(); + if (arc.nextstate >= state_color.size()) { + nstates = arc.nextstate + 1; + state_color.resize(nstates, kDfsWhite); + } + if (!filter(arc)) { + aiter.Next(); + continue; + } + const auto next_color = state_color[arc.nextstate]; + switch (next_color) { + default: + case kDfsWhite: + dfs = visitor->TreeArc(s, arc); + if (!dfs) break; + state_color[arc.nextstate] = kDfsGrey; + state_stack.push(new (&state_pool) + internal::DfsState(fst, arc.nextstate)); + dfs = visitor->InitState(arc.nextstate, root); + break; + case kDfsGrey: + dfs = visitor->BackArc(s, arc); + aiter.Next(); + break; + case kDfsBlack: + dfs = visitor->ForwardOrCrossArc(s, arc); + aiter.Next(); + break; + } + } + if (access_only) break; + // Finds next tree root. + for (root = root == start ? 0 : root + 1; + root < nstates && state_color[root] != kDfsWhite; ++root) { + } + // Checks for a state beyond the largest known state. + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_color.push_back(kDfsWhite); + break; + } + } + } + } + visitor->FinishVisit(); +} + +template +void DfsVisit(const Fst &fst, Visitor *visitor) { + DfsVisit(fst, visitor, AnyArcFilter()); +} + +} // namespace fst + +#endif // FST_DFS_VISIT_H_ diff --git a/projects/llm_framework/include/fst/difference.h b/projects/llm_framework/include/fst/difference.h new file mode 100644 index 00000000..f073b0e1 --- /dev/null +++ b/projects/llm_framework/include/fst/difference.h @@ -0,0 +1,205 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to compute the difference between two FSAs. + +#ifndef FST_DIFFERENCE_H_ +#define FST_DIFFERENCE_H_ + +#include + + +#include +#include +#include + + +namespace fst { + +template >, + class Filter = SequenceComposeFilter, + class StateTable = + GenericComposeStateTable> +struct DifferenceFstOptions + : public ComposeFstOptions { + explicit DifferenceFstOptions(const CacheOptions &opts = CacheOptions(), + M *matcher1 = nullptr, M *matcher2 = nullptr, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : ComposeFstOptions(opts, matcher1, matcher2, + filter, state_table) {} +}; + +// Computes the difference between two FSAs. This version is a delayed FST. +// Only strings that are in the first automaton but not in second are retained +// in the result. +// +// The first argument must be an acceptor; the second argument must be an +// unweighted, epsilon-free, deterministic acceptor. One of the arguments must +// be label-sorted. +// +// Complexity: same as ComposeFst. +// +// Caveats: same as ComposeFst. +template +class DifferenceFst : public ComposeFst { + public: + using Arc = A; + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + + using ComposeFst::CreateBase1; + + // A - B = A ^ B'. + DifferenceFst(const Fst &fst1, const Fst &fst2, + const CacheOptions &opts = CacheOptions()) + : ComposeFst(CreateDifferenceImplWithCacheOpts(fst1, fst2, opts)) { + if (!fst1.Properties(kAcceptor, true)) { + FSTERROR() << "DifferenceFst: 1st argument not an acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + template + DifferenceFst( + const Fst &fst1, const Fst &fst2, + const DifferenceFstOptions &opts) + : ComposeFst( + CreateDifferenceImplWithDifferenceOpts(fst1, fst2, opts)) { + if (!fst1.Properties(kAcceptor, true)) { + FSTERROR() << "DifferenceFst: 1st argument not an acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + DifferenceFst(const DifferenceFst &fst, bool safe = false) + : ComposeFst(fst, safe) {} + + // Get a copy of this DifferenceFst. See Fst<>::Copy() for further doc. + DifferenceFst *Copy(bool safe = false) const override { + return new DifferenceFst(*this, safe); + } + + private: + using Impl = internal::ComposeFstImplBase; + using ImplToFst::GetImpl; + + static std::shared_ptr CreateDifferenceImplWithCacheOpts( + const Fst &fst1, const Fst &fst2, const CacheOptions &opts) { + using RM = RhoMatcher>>; + ComplementFst cfst(fst2); + ComposeFstOptions copts( + CacheOptions(), new RM(fst1, MATCH_NONE), + new RM(cfst, MATCH_INPUT, ComplementFst::kRhoLabel)); + return CreateBase1(fst1, cfst, copts); + } + + template + static std::shared_ptr CreateDifferenceImplWithDifferenceOpts( + const Fst &fst1, const Fst &fst2, + const DifferenceFstOptions &opts) { + using RM = RhoMatcher; + ComplementFst cfst(fst2); + ComposeFstOptions copts(opts); + copts.matcher1 = new RM(fst1, MATCH_NONE, kNoLabel, MATCHER_REWRITE_ALWAYS, + opts.matcher1); + copts.matcher2 = new RM(cfst, MATCH_INPUT, ComplementFst::kRhoLabel, + MATCHER_REWRITE_ALWAYS, opts.matcher2); + return CreateBase1(fst1, cfst, copts); + } +}; + +// Specialization for DifferenceFst. +template +class StateIterator> + : public StateIterator> { + public: + explicit StateIterator(const DifferenceFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for DifferenceFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const DifferenceFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +using DifferenceOptions = ComposeOptions; + +// Useful alias when using StdArc. +using StdDifferenceFst = DifferenceFst; + +using DifferenceOptions = ComposeOptions; + +// Computes the difference between two FSAs. This version writes the difference +// to an output MutableFst. Only strings that are in the first automaton but not +// in the second are retained in the result. +// +// The first argument must be an acceptor; the second argument must be an +// unweighted, epsilon-free, deterministic acceptor. One of the arguments must +// be label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template +void Difference(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const DifferenceOptions &opts = DifferenceOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, nopts); + break; + } + case SEQUENCE_FILTER: { + DifferenceFstOptions dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case ALT_SEQUENCE_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case MATCH_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case NO_MATCH_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case NULL_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case TRIVIAL_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_DIFFERENCE_H_ diff --git a/projects/llm_framework/include/fst/disambiguate.h b/projects/llm_framework/include/fst/disambiguate.h new file mode 100644 index 00000000..7e3ba37c --- /dev/null +++ b/projects/llm_framework/include/fst/disambiguate.h @@ -0,0 +1,564 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to disambiguate an FST. + +#ifndef FST_DISAMBIGUATE_H_ +#define FST_DISAMBIGUATE_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +struct DisambiguateOptions : public DeterminizeOptions { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit DisambiguateOptions(float delta = kDelta, + Weight weight = Weight::Zero(), + StateId n = kNoStateId, Label label = 0) + : DeterminizeOptions(delta, std::move(weight), n, label, + DETERMINIZE_FUNCTIONAL) {} +}; + +namespace internal { + +// A determinization filter based on a subset element relation. The relation is +// assumed to be reflexive and symmetric. +template +class RelationDeterminizeFilter { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = IntegerFilterState; + using StateTuple = DeterminizeStateTuple; + using Subset = typename StateTuple::Subset; + using Element = typename StateTuple::Element; + using LabelMap = std::multimap>; + + // This is needed (e.g.) to go into the gallic domain for transducers; there + // is no need to rebind the relation since its use here only depends on the + // state IDs. + template + struct rebind { + using Other = RelationDeterminizeFilter; + }; + + explicit RelationDeterminizeFilter(const Fst &fst) + : fst_(fst.Copy()), r_(new Relation()), s_(kNoStateId), head_(nullptr) {} + + // Ownership of the relation is given to this class. + RelationDeterminizeFilter(const Fst &fst, Relation *r) + : fst_(fst.Copy()), r_(r), s_(kNoStateId), head_(0) {} + + // Ownership of the relation is given to this class. + RelationDeterminizeFilter(const Fst &fst, Relation *r, + std::vector *head) + : fst_(fst.Copy()), r_(r), s_(kNoStateId), head_(head) {} + + // This is needed, e.g., to go into the gallic domain for transducers. + // Ownership of the templated filter argument is given to this class. + template + RelationDeterminizeFilter(const Fst &fst, Filter *filter) + : fst_(fst.Copy()), + r_(new Relation(filter->GetRelation())), + s_(kNoStateId), + head_(filter->GetHeadStates()) { + delete filter; + } + + // Copy constructor; the FST can be passed if it has been deep-copied. + RelationDeterminizeFilter(const RelationDeterminizeFilter &filter, + const Fst *fst = nullptr) + : fst_(fst ? fst->Copy() : filter.fst_->Copy()), + r_(new Relation(*filter.r_)), + s_(kNoStateId), + head_() {} + + FilterState Start() const { return FilterState(fst_->Start()); } + + void SetState(StateId s, const StateTuple &tuple) { + if (s_ != s) { + s_ = s; + tuple_ = &tuple; + const auto head = tuple.filter_state.GetState(); + is_final_ = fst_->Final(head) != Weight::Zero(); + if (head_) { + if (head_->size() <= s) head_->resize(s + 1, kNoStateId); + (*head_)[s] = head; + } + } + } + + // Filters transition, possibly modifying label map. Returns true if arc is + // added to label map. + bool FilterArc(const Arc &arc, const Element &src_element, + const Element &dest_element, LabelMap *label_map) const; + + // Filters super-final transition, returning new final weight. + Weight FilterFinal(const Weight final_weight, const Element &element) const { + return is_final_ ? final_weight : Weight::Zero(); + } + + static uint64 Properties(uint64 props) { + return props & ~(kIDeterministic | kODeterministic); + } + + const Relation &GetRelation() { return *r_; } + + std::vector *GetHeadStates() { return head_; } + + private: + // Pairs arc labels with state tuples with possible heads and empty subsets. + void InitLabelMap(LabelMap *label_map) const; + + std::unique_ptr> fst_; // Input FST. + std::unique_ptr r_; // Relation compatible with inv. trans. fnc. + StateId s_; // Current state. + const StateTuple *tuple_; // Current tuple. + bool is_final_; // Is the current head state final? + std::vector *head_; // Head state for a given state, + // owned by the Disambiguator. +}; + +template +bool RelationDeterminizeFilter::FilterArc( + const Arc &arc, const Element &src_element, const Element &dest_element, + LabelMap *label_map) const { + bool added = false; + if (label_map->empty()) InitLabelMap(label_map); + // Adds element to state tuple if element state is related to tuple head. + for (auto liter = label_map->lower_bound(arc.ilabel); + liter != label_map->end() && liter->first == arc.ilabel; ++liter) { + auto *dest_tuple = liter->second.dest_tuple; + const auto dest_head = dest_tuple->filter_state.GetState(); + if ((*r_)(dest_element.state_id, dest_head)) { + dest_tuple->subset.push_front(dest_element); + added = true; + } + } + return added; +} + +template +void RelationDeterminizeFilter::InitLabelMap( + LabelMap *label_map) const { + const auto src_head = tuple_->filter_state.GetState(); + Label label = kNoLabel; + StateId nextstate = kNoStateId; + for (ArcIterator> aiter(*fst_, src_head); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); + // Continues if multiarc. + if (arc.ilabel == label && arc.nextstate == nextstate) continue; + DeterminizeArc det_arc(arc); + det_arc.dest_tuple->filter_state = FilterState(arc.nextstate); + label_map->insert(std::make_pair(arc.ilabel, det_arc)); + label = arc.ilabel; + nextstate = arc.nextstate; + } +} + +// Helper class to disambiguate an FST via Disambiguate(). +template +class Disambiguator { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // IDs arcs with state ID and arc position. Arc position -1 indicates final + // (super-final transition). + using ArcId = std::pair; + + Disambiguator() : error_(false) {} + + void Disambiguate( + const Fst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts = DisambiguateOptions()) { + VectorFst sfst(ifst); + Connect(&sfst); + ArcSort(&sfst, ArcCompare()); + PreDisambiguate(sfst, ofst, opts); + ArcSort(ofst, ArcCompare()); + FindAmbiguities(*ofst); + RemoveSplits(ofst); + MarkAmbiguities(); + RemoveAmbiguities(ofst); + if (error_) ofst->SetProperties(kError, kError); + } + + private: + // Comparison functor for comparing input labels and next states of arcs. This + // sort order facilitates the predisambiguation. + class ArcCompare { + public: + bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.ilabel < arc2.ilabel || + (arc1.ilabel == arc2.ilabel && arc1.nextstate < arc2.nextstate); + } + + uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kILabelSorted | + (props & kAcceptor ? kOLabelSorted : 0); + } + }; + + // Comparison functor for comparing transitions represented by their arc ID. + // This sort order facilitates ambiguity detection. + class ArcIdCompare { + public: + explicit ArcIdCompare(const std::vector &head) : head_(head) {} + + bool operator()(const ArcId &a1, const ArcId &a2) const { + // Sort first by source head state... + const auto src1 = a1.first; + const auto src2 = a2.first; + const auto head1 = head_[src1]; + const auto head2 = head_[src2]; + if (head1 < head2) return true; + if (head2 < head1) return false; + // ...then by source state... + if (src1 < src2) return true; + if (src2 < src1) return false; + // ...then by position. + return a1.second < a2.second; + } + + private: + const std::vector &head_; + }; + + // A relation that determines if two states share a common future. + class CommonFuture { + public: + using StateTable = GenericComposeStateTable; + using StateTuple = typename StateTable::StateTuple; + + // Needed for compilation with DeterminizeRelationFilter. + CommonFuture() { + FSTERROR() << "Disambiguate::CommonFuture: FST not provided"; + } + + explicit CommonFuture(const Fst &ifst) { + using M = Matcher>; + ComposeFstOptions> opts; + // Ensures composition is between acceptors. + const bool trans = ifst.Properties(kNotAcceptor, true); + const auto *fsa = + trans ? new ProjectFst(ifst, PROJECT_INPUT) : &ifst; + opts.state_table = new StateTable(*fsa, *fsa); + const ComposeFst cfst(*fsa, *fsa, opts); + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, nullptr, &coaccess, &props); + DfsVisit(cfst, &scc_visitor); + for (StateId s = 0; s < coaccess.size(); ++s) { + if (coaccess[s]) { + related_.insert(opts.state_table->Tuple(s).StatePair()); + } + } + if (trans) delete fsa; + } + + bool operator()(const StateId s1, StateId s2) const { + return related_.count(std::make_pair(s1, s2)) > 0; + } + + private: + // States s1 and s2 resp. are in this relation iff they there is a + // path from s1 to a final state that has the same label as some + // path from s2 to a final state. + std::set> related_; + }; + + using ArcIdMap = std::multimap; + + // Inserts candidate into the arc ID map. + inline void InsertCandidate(StateId s1, StateId s2, const ArcId &a1, + const ArcId &a2) { + candidates_->insert(head_[s1] > head_[s2] ? std::make_pair(a1, a2) + : std::make_pair(a2, a1)); + } + + // Returns the arc corresponding to ArcId a. + static Arc GetArc(const Fst &fst, ArcId aid) { + if (aid.second == -1) { // Returns super-final transition. + return Arc(kNoLabel, kNoLabel, fst.Final(aid.first), kNoStateId); + } else { + ArcIterator> aiter(fst, aid.first); + aiter.Seek(aid.second); + return aiter.Value(); + } + } + + // Outputs an equivalent FST whose states are subsets of states that have a + // future path in common. + void PreDisambiguate(const ExpandedFst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts); + + // Finds transitions that are ambiguous candidates in the result of + // PreDisambiguate. + void FindAmbiguities(const ExpandedFst &fst); + + // Finds transition pairs that are ambiguous candidates from two specified + // source states. + void FindAmbiguousPairs(const ExpandedFst &fst, StateId s1, StateId s2); + + // Marks ambiguous transitions to be removed. + void MarkAmbiguities(); + + // Deletes spurious ambiguous transitions (due to quantization). + void RemoveSplits(MutableFst *ofst); + + // Deletes actual ambiguous transitions. + void RemoveAmbiguities(MutableFst *ofst); + + // States s1 and s2 are in this relation iff there is a path from the initial + // state to s1 that has the same label as some path from the initial state to + // s2. We store only state pairs s1, s2 such that s1 <= s2. + std::set> coreachable_; + + // Queue of disambiguation-related states to be processed. We store only + // state pairs s1, s2 such that s1 <= s2. + std::list> queue_; + + // Head state in the pre-disambiguation for a given state. + std::vector head_; + + // Maps from a candidate ambiguous arc A to each ambiguous candidate arc B + // with the same label and destination state as A, whose source state s' is + // coreachable with the source state s of A, and for which head(s') < head(s). + std::unique_ptr candidates_; + + // Set of ambiguous transitions to be removed. + std::set ambiguous_; + + // States to merge due to quantization issues. + std::unique_ptr> merge_; + // Marks error condition. + bool error_; + + Disambiguator(const Disambiguator &) = delete; + Disambiguator &operator=(const Disambiguator &) = delete; +}; + +template +void Disambiguator::PreDisambiguate(const ExpandedFst &ifst, + MutableFst *ofst, + const DisambiguateOptions &opts) { + using CommonDivisor = DefaultCommonDivisor; + using Filter = RelationDeterminizeFilter; + // Subset elements with states s1 and s2 (resp.) are in this relation iff they + // there is a path from s1 to a final state that has the same label as some + // path from s2 to a final state. + auto *common_future = new CommonFuture(ifst); + DeterminizeFstOptions nopts; + nopts.delta = opts.delta; + nopts.subsequential_label = opts.subsequential_label; + nopts.filter = new Filter(ifst, common_future, &head_); + // The filter takes ownership of 'common_future', and determinization takes + // ownership of the filter itself. + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + /* TODO(riley): fails regression test; understand why + if (ifst.Properties(kAcceptor, true)) { + std::vector idistance, odistance; + ShortestDistance(ifst, &idistance, true); + DeterminizeFst dfst(ifst, &idistance, &odistance, nopts); + PruneOptions< Arc, AnyArcFilter> popts(opts.weight_threshold, + opts.state_threshold, + AnyArcFilter(), + &odistance); + Prune(dfst, ofst, popts); + } else */ { + *ofst = DeterminizeFst(ifst, nopts); + Prune(ofst, opts.weight_threshold, opts.state_threshold); + } + } else { + *ofst = DeterminizeFst(ifst, nopts); + } + head_.resize(ofst->NumStates(), kNoStateId); +} + +template +void Disambiguator::FindAmbiguities(const ExpandedFst &fst) { + if (fst.Start() == kNoStateId) return; + candidates_.reset(new ArcIdMap(ArcIdCompare(head_))); + const auto start_pr = std::make_pair(fst.Start(), fst.Start()); + coreachable_.insert(start_pr); + queue_.push_back(start_pr); + while (!queue_.empty()) { + const auto &pr = queue_.front(); + const auto s1 = pr.first; + const auto s2 = pr.second; + queue_.pop_front(); + FindAmbiguousPairs(fst, s1, s2); + } +} + +template +void Disambiguator::FindAmbiguousPairs(const ExpandedFst &fst, + StateId s1, StateId s2) { + if (fst.NumArcs(s2) > fst.NumArcs(s1)) FindAmbiguousPairs(fst, s2, s1); + SortedMatcher> matcher(fst, MATCH_INPUT); + matcher.SetState(s2); + for (ArcIterator> aiter(fst, s1); !aiter.Done(); aiter.Next()) { + const auto &arc1 = aiter.Value(); + const ArcId a1(s1, aiter.Position()); + if (matcher.Find(arc1.ilabel)) { + for (; !matcher.Done(); matcher.Next()) { + const auto &arc2 = matcher.Value(); + // Continues on implicit epsilon match. + if (arc2.ilabel == kNoLabel) continue; + const ArcId a2(s2, matcher.Position()); + // Actual transition is ambiguous. + if (s1 != s2 && arc1.nextstate == arc2.nextstate) { + InsertCandidate(s1, s2, a1, a2); + } + const auto spr = arc1.nextstate <= arc2.nextstate + ? std::make_pair(arc1.nextstate, arc2.nextstate) + : std::make_pair(arc2.nextstate, arc1.nextstate); + // Not already marked as coreachable? + if (coreachable_.insert(spr).second) { + // Only possible if state split by quantization issues. + if (spr.first != spr.second && + head_[spr.first] == head_[spr.second]) { + if (!merge_) { + merge_.reset(new UnionFind(fst.NumStates(), kNoStateId)); + merge_->MakeAllSet(fst.NumStates()); + } + merge_->Union(spr.first, spr.second); + } else { + queue_.push_back(spr); + } + } + } + } + } + // Super-final transition is ambiguous. + if (s1 != s2 && fst.Final(s1) != Weight::Zero() && + fst.Final(s2) != Weight::Zero()) { + const ArcId a1(s1, -1); + const ArcId a2(s2, -1); + InsertCandidate(s1, s2, a1, a2); + } +} + +template +void Disambiguator::MarkAmbiguities() { + if (!candidates_) return; + for (auto it = candidates_->begin(); it != candidates_->end(); ++it) { + const auto a = it->first; + const auto b = it->second; + // If b is not to be removed, then a is. + if (ambiguous_.count(b) == 0) ambiguous_.insert(a); + } + coreachable_.clear(); + candidates_.reset(); +} + +template +void Disambiguator::RemoveSplits(MutableFst *ofst) { + if (!merge_) return; + // Merges split states to remove spurious ambiguities. + for (StateIterator> siter(*ofst); !siter.Done(); + siter.Next()) { + for (MutableArcIterator> aiter(ofst, siter.Value()); + !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + const auto nextstate = merge_->FindSet(arc.nextstate); + if (nextstate != arc.nextstate) { + arc.nextstate = nextstate; + aiter.SetValue(arc); + } + } + } + // Repeats search for actual ambiguities on modified FST. + coreachable_.clear(); + merge_.reset(); + candidates_.reset(); + FindAmbiguities(*ofst); + if (merge_) { // Shouldn't get here; sanity test. + FSTERROR() << "Disambiguate: Unable to remove spurious ambiguities"; + error_ = true; + return; + } +} + +template +void Disambiguator::RemoveAmbiguities(MutableFst *ofst) { + if (ambiguous_.empty()) return; + // Adds dead state to redirect ambiguous transitions to be removed. + const auto dead = ofst->AddState(); + for (auto it = ambiguous_.begin(); it != ambiguous_.end(); ++it) { + const auto pos = it->second; + if (pos >= 0) { // Actual transition. + MutableArcIterator> aiter(ofst, it->first); + aiter.Seek(pos); + auto arc = aiter.Value(); + arc.nextstate = dead; + aiter.SetValue(arc); + } else { // Super-final transition. + ofst->SetFinal(it->first, Weight::Zero()); + } + } + Connect(ofst); + ambiguous_.clear(); +} + +} // namespace internal + +// Disambiguates a weighted FST. This version writes the disambiguated FST to an +// output MutableFst. The result will be an equivalent FST that has the +// property that there are not two distinct paths from the initial state to a +// final state with the same input labeling. +// +// The weights must be (weakly) left divisible (valid for Tropical and +// LogWeight). +// +// Complexity: +// +// Disambiguable: exponential (polynomial in the size of the output). +// Non-disambiguable: does not terminate. +// +// The disambiguable transducers include all automata and functional transducers +// that are unweighted or that are acyclic or that are unambiguous. +// +// For more information, see: +// +// Mohri, M. and Riley, M. 2015. On the disambiguation of weighted automata. +// In CIAA, pages 263-278. +template +void Disambiguate( + const Fst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts = DisambiguateOptions()) { + internal::Disambiguator disambiguator; + disambiguator.Disambiguate(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_DISAMBIGUATE_H_ diff --git a/projects/llm_framework/include/fst/edit-fst.h b/projects/llm_framework/include/fst/edit-fst.h new file mode 100644 index 00000000..e7bdfd61 --- /dev/null +++ b/projects/llm_framework/include/fst/edit-fst.h @@ -0,0 +1,702 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation that allows non-destructive edit operations on an +// existing FST. +// +// The EditFst class enables non-destructive edit operations on a wrapped +// ExpandedFst. The implementation uses copy-on-write semantics at the node +// level: if a user has an underlying fst on which he or she wants to perform a +// relatively small number of edits (read: mutations), then this implementation +// will copy the edited node to an internal MutableFst and perform any edits in +// situ on that copied node. This class supports all the methods of MutableFst +// except for DeleteStates(const std::vector &); thus, new nodes may +// also be +// added, and one may add transitions from existing nodes of the wrapped fst to +// new nodes. +// +// N.B.: The documentation for Fst::Copy(true) says that its behavior is +// undefined if invoked on an fst that has already been accessed. This class +// requires that the Fst implementation it wraps provides consistent, reliable +// behavior when its Copy(true) method is invoked, where consistent means +// the graph structure, graph properties and state numbering and do not change. +// VectorFst and CompactFst, for example, are both well-behaved in this regard. + +#ifndef FST_EDIT_FST_H_ +#define FST_EDIT_FST_H_ + +#include +#include +#include + +#include + +#include + + +namespace fst { +namespace internal { + +// The EditFstData class is a container for all mutable data for EditFstImpl; +// also, this class provides most of the actual implementation of what EditFst +// does (that is, most of EditFstImpl's methods delegate to methods in this, the +// EditFstData class). Instances of this class are reference-counted and can be +// shared between otherwise independent EditFstImpl instances. This scheme +// allows EditFstImpl to implement the thread-safe, copy-on-write semantics +// required by Fst::Copy(true). +// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstData instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template , + typename MutableFstT = VectorFst> +class EditFstData { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + EditFstData() : num_new_states_(0) {} + + EditFstData(const EditFstData &other) + : edits_(other.edits_), + external_to_internal_ids_(other.external_to_internal_ids_), + edited_final_weights_(other.edited_final_weights_), + num_new_states_(other.num_new_states_) {} + + ~EditFstData() {} + + static EditFstData *Read( + std::istream &strm, const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + // Serialize all private data members of this class. + FstWriteOptions edits_opts(opts); + edits_opts.write_header = true; // Force writing contained header. + edits_.Write(strm, edits_opts); + WriteType(strm, external_to_internal_ids_); + WriteType(strm, edited_final_weights_); + WriteType(strm, num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFstData::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + StateId NumNewStates() const { return num_new_states_; } + + // accessor methods for the fst holding edited states + StateId EditedStart() const { return edits_.Start(); } + + Weight Final(StateId s, const WrappedFstT *wrapped) const { + auto final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->Final(s) + : edits_.Final(it->second); + } else { + return final_weight_it->second; + } + } + + size_t NumArcs(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumArcs(s) + : edits_.NumArcs(it->second); + } + + size_t NumInputEpsilons(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumInputEpsilons(s) + : edits_.NumInputEpsilons(it->second); + } + + size_t NumOutputEpsilons(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumOutputEpsilons(s) + : edits_.NumOutputEpsilons(it->second); + } + + void SetEditedProperties(uint64 props, uint64 mask) { + edits_.SetProperties(props, mask); + } + + // Non-const MutableFst operations. + + // Sets the start state for this FST. + void SetStart(StateId s) { edits_.SetStart(s); } + + // Sets the final state for this FST. + Weight SetFinal(StateId s, Weight w, const WrappedFstT *wrapped) { + Weight old_weight = Final(s, wrapped); + auto it = GetEditedIdMapIterator(s); + // If we haven't already edited state s, don't add it to edited_ (which can + // be expensive if s has many transitions); just use the + // edited_final_weights_ map. + if (it == NotInEditedMap()) { + edited_final_weights_[s] = w; + } else { + edits_.SetFinal(GetEditableInternalId(s, wrapped), w); + } + return old_weight; + } + + // Adds a new state to this FST, initially with no arcs. + StateId AddState(StateId curr_num_states) { + StateId internal_state_id = edits_.AddState(); + StateId external_state_id = curr_num_states; + external_to_internal_ids_[external_state_id] = internal_state_id; + num_new_states_++; + return external_state_id; + } + + // Adds the specified arc to the specified state of this FST. + const Arc *AddArc(StateId s, const Arc &arc, const WrappedFstT *wrapped) { + const auto internal_id = GetEditableInternalId(s, wrapped); + const auto num_arcs = edits_.NumArcs(internal_id); + ArcIterator arc_it(edits_, internal_id); + const Arc *prev_arc = nullptr; + if (num_arcs > 0) { + // grab the final arc associated with this state in edits_ + arc_it.Seek(num_arcs - 1); + prev_arc = &(arc_it.Value()); + } + edits_.AddArc(internal_id, arc); + return prev_arc; + } + + void DeleteStates() { + edits_.DeleteStates(); + num_new_states_ = 0; + external_to_internal_ids_.clear(); + edited_final_weights_.clear(); + } + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped), n); + } + + // Removes all outgoing arcs from the specified state. + void DeleteArcs(StateId s, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped)); + } + + // End methods for non-const MutableFst operations. + + // Provides information for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data, + const WrappedFstT *wrapped) const { + auto id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + VLOG(3) << "EditFstData::InitArcIterator: iterating on state " << s + << " of original fst"; + wrapped->InitArcIterator(s, data); + } else { + VLOG(2) << "EditFstData::InitArcIterator: iterating on edited state " << s + << " (internal state id: " << id_map_it->second << ")"; + edits_.InitArcIterator(id_map_it->second, data); + } + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData *data, + const WrappedFstT *wrapped) { + data->base = new MutableArcIterator( + &edits_, GetEditableInternalId(s, wrapped)); + } + + // Prints out the map from external to internal state id's (for debugging + // purposes). + void PrintMap() { + for (auto map_it = external_to_internal_ids_.begin(); + map_it != NotInEditedMap(); ++map_it) { + LOG(INFO) << "(external,internal)=(" << map_it->first << "," + << map_it->second << ")"; + } + } + + private: + // Returns the iterator of the map from external to internal state id's + // of edits_ for the specified external state id. + typename std::unordered_map::const_iterator + GetEditedIdMapIterator(StateId s) const { + return external_to_internal_ids_.find(s); + } + + typename std::unordered_map::const_iterator + NotInEditedMap() const { + return external_to_internal_ids_.end(); + } + + typename std::unordered_map::const_iterator + GetFinalWeightIterator(StateId s) const { + return edited_final_weights_.find(s); + } + + typename std::unordered_map::const_iterator + NotInFinalWeightMap() const { + return edited_final_weights_.end(); + } + + // Returns the internal state ID of the specified external ID if the state has + // already been made editable, or else copies the state from wrapped_ to + // edits_ and returns the state id of the newly editable state in edits_. + StateId GetEditableInternalId(StateId s, const WrappedFstT *wrapped) { + auto id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + StateId new_internal_id = edits_.AddState(); + VLOG(2) << "EditFstData::GetEditableInternalId: editing state " << s + << " of original fst; new internal state id:" << new_internal_id; + external_to_internal_ids_[s] = new_internal_id; + for (ArcIterator> arc_iterator(*wrapped, s); + !arc_iterator.Done(); arc_iterator.Next()) { + edits_.AddArc(new_internal_id, arc_iterator.Value()); + } + // Copies the final weight. + auto final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + edits_.SetFinal(new_internal_id, wrapped->Final(s)); + } else { + edits_.SetFinal(new_internal_id, final_weight_it->second); + edited_final_weights_.erase(s); + } + return new_internal_id; + } else { + return id_map_it->second; + } + } + + // A mutable FST (by default, a VectorFst) to contain new states, and/or + // copies of states from a wrapped ExpandedFst that have been modified in + // some way. + MutableFstT edits_; + // A mapping from external state IDs to the internal IDs of states that + // appear in edits_. + std::unordered_map external_to_internal_ids_; + // A mapping from external state IDs to final state weights assigned to + // those states. The states in this map are *only* those whose final weight + // has been modified; if any other part of the state has been modified, + // the entire state is copied to edits_, and all modifications reside there. + std::unordered_map edited_final_weights_; + // The number of new states added to this mutable fst impl, which is <= the + // number of states in edits_ (since edits_ contains both edited *and* new + // states). + StateId num_new_states_; +}; + +// EditFstData method implementations: just the Read method. +template +EditFstData * +EditFstData::Read(std::istream &strm, + const FstReadOptions &opts) { + auto *data = new EditFstData(); + // next read in MutabelFstT machine that stores edits + FstReadOptions edits_opts(opts); + // Contained header was written out, so read it in. + edits_opts.header = nullptr; + + // Because our internal representation of edited states is a solid object + // of type MutableFstT (defaults to VectorFst) and not a pointer, + // and because the static Read method allocates a new object on the heap, + // we need to call Read, check if there was a failure, use + // MutableFstT::operator= to assign the object (not the pointer) to the + // edits_ data member (which will increase the ref count by 1 on the impl) + // and, finally, delete the heap-allocated object. + std::unique_ptr edits(MutableFstT::Read(strm, edits_opts)); + if (!edits) return nullptr; + data->edits_ = *edits; + edits.reset(); + // Finally, reads in rest of private data members. + ReadType(strm, &data->external_to_internal_ids_); + ReadType(strm, &data->edited_final_weights_); + ReadType(strm, &data->num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFst::Read: read failed: " << opts.source; + return nullptr; + } + return data; +} + +// This class enables non-destructive edit operations on a wrapped ExpandedFst. +// The implementation uses copy-on-write semantics at the node level: if a user +// has an underlying fst on which he or she wants to perform a relatively small +// number of edits (read: mutations), then this implementation will copy the +// edited node to an internal MutableFst and perform any edits in situ on that +// copied node. This class supports all the methods of MutableFst except for +// DeleteStates(const std::vector &); thus, new nodes may also be +// added, and +// one may add transitions from existing nodes of the wrapped fst to new nodes. +// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstImpl instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template , + typename MutableFstT = VectorFst> +class EditFstImpl : public FstImpl { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + // Constructs an editable FST implementation with no states. Effectively, this + // initially-empty fst will in every way mimic the behavior of a + // VectorFst---more precisely, a VectorFstImpl instance---but with slightly + // slower performance (by a constant factor), due to the fact that + // this class maintains a mapping between external state id's and + // their internal equivalents. + EditFstImpl() : wrapped_(new MutableFstT()) { + FstImpl::SetType("edit"); + InheritPropertiesFromWrapped(); + data_ = std::make_shared>(); + } + + // Wraps the specified ExpandedFst. This constructor requires that the + // specified Fst is an ExpandedFst instance. This requirement is only enforced + // at runtime. (See below for the reason.) + // + // This library uses the pointer-to-implementation or "PIMPL" design pattern. + // In particular, to make it convenient to bind an implementation class to its + // interface, there are a pair of template "binder" classes, one for immutable + // and one for mutable fst's (ImplToFst and ImplToMutableFst, respectively). + // As it happens, the API for the ImplToMutableFst class requires that + // the implementation class--the template parameter "I"--have a constructor + // taking a const Fst reference. Accordingly, the constructor here must + // perform a static_cast to the WrappedFstT type required by EditFst and + // therefore EditFstImpl. + explicit EditFstImpl(const Fst &wrapped) + : wrapped_(static_cast(wrapped.Copy())) { + FstImpl::SetType("edit"); + data_ = std::make_shared>(); + // have edits_ inherit all properties from wrapped_ + data_->SetEditedProperties(wrapped_->Properties(kFstProperties, false), + kFstProperties); + InheritPropertiesFromWrapped(); + } + + // A copy constructor for this implementation class, used to implement + // the Copy() method of the Fst interface. + EditFstImpl(const EditFstImpl &impl) + : FstImpl(), + wrapped_(static_cast(impl.wrapped_->Copy(true))), + data_(impl.data_) { + SetProperties(impl.Properties()); + } + + // const Fst/ExpandedFst operations, declared in the Fst and ExpandedFst + // interfaces + StateId Start() const { + const auto edited_start = data_->EditedStart(); + return edited_start == kNoStateId ? wrapped_->Start() : edited_start; + } + + Weight Final(StateId s) const { return data_->Final(s, wrapped_.get()); } + + size_t NumArcs(StateId s) const { return data_->NumArcs(s, wrapped_.get()); } + + size_t NumInputEpsilons(StateId s) const { + return data_->NumInputEpsilons(s, wrapped_.get()); + } + + size_t NumOutputEpsilons(StateId s) const { + return data_->NumOutputEpsilons(s, wrapped_.get()); + } + + StateId NumStates() const { + return wrapped_->NumStates() + data_->NumNewStates(); + } + + static EditFstImpl *Read( + std::istream &strm, const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(NumStates()); + FstWriteOptions header_opts(opts); + // Allows the contained FST to hold any symbols. + header_opts.write_isymbols = false; + header_opts.write_osymbols = false; + WriteHeader(strm, header_opts, kFileVersion, &hdr); + // First, serializes the wrapped FST to stream. + FstWriteOptions wrapped_opts(opts); + // Forcse writing the contained header. + wrapped_opts.write_header = true; + wrapped_->Write(strm, wrapped_opts); + data_->Write(strm, opts); + strm.flush(); + if (!strm) { + LOG(ERROR) << "EditFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + // Sets the start state for this FST. + void SetStart(StateId s) { + MutateCheck(); + data_->SetStart(s); + SetProperties(SetStartProperties(FstImpl::Properties())); + } + + // Sets the final state for this fst. + void SetFinal(StateId s, Weight weight) { + MutateCheck(); + Weight old_weight = data_->SetFinal(s, weight, wrapped_.get()); + SetProperties( + SetFinalProperties(FstImpl::Properties(), old_weight, weight)); + } + + // Adds a new state to this fst, initially with no arcs. + StateId AddState() { + MutateCheck(); + SetProperties(AddStateProperties(FstImpl::Properties())); + return data_->AddState(NumStates()); + } + + // Adds the specified arc to the specified state of this fst. + void AddArc(StateId s, const Arc &arc) { + MutateCheck(); + const auto *prev_arc = data_->AddArc(s, arc, wrapped_.get()); + SetProperties( + AddArcProperties(FstImpl::Properties(), s, arc, prev_arc)); + } + + void DeleteStates(const std::vector &dstates) { + FSTERROR() << ": EditFstImpl::DeleteStates(const std::vector&): " + << " not implemented"; + SetProperties(kError, kError); + } + + // Deletes all states in this fst. + void DeleteStates(); + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n) { + MutateCheck(); + data_->DeleteArcs(s, n, wrapped_.get()); + SetProperties(DeleteArcsProperties(FstImpl::Properties())); + } + + // Removes all outgoing arcs from the specified state. + void DeleteArcs(StateId s) { + MutateCheck(); + data_->DeleteArcs(s, wrapped_.get()); + SetProperties(DeleteArcsProperties(FstImpl::Properties())); + } + + void ReserveStates(StateId s) {} + + void ReserveArcs(StateId s, size_t n) {} + + // Ends non-const MutableFst operations. + + // Provides information for the generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = NumStates(); + } + + // Provides information for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + data_->InitArcIterator(s, data, wrapped_.get()); + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData *data) { + MutateCheck(); + data_->InitMutableArcIterator(s, data, wrapped_.get()); + } + + private: + // Properties always true of this FST class. + static constexpr uint64 kStaticProperties = kExpanded | kMutable; + // Current file format version. + static constexpr int kFileVersion = 2; + // Minimum file format version supported + static constexpr int kMinFileVersion = 2; + + // Causes this FST to inherit all the properties from its wrapped FST, except + // for the two properties that always apply to EditFst instances: kExpanded + // and kMutable. + void InheritPropertiesFromWrapped() { + SetProperties(wrapped_->Properties(kCopyProperties, false) | + kStaticProperties); + SetInputSymbols(wrapped_->InputSymbols()); + SetOutputSymbols(wrapped_->OutputSymbols()); + } + + // This method ensures that any operations that alter the mutable data + // portion of this EditFstImpl cause the data_ member to be copied when its + // reference count is greater than 1. Note that this method is distinct from + // MutableFst::Mutate, which gets invoked whenever one of the basic mutation + // methods defined in MutableFst is invoked, such as SetInputSymbols. + // The MutateCheck here in EditFstImpl is invoked whenever one of the + // mutating methods specifically related to the types of edits provided + // by EditFst is performed, such as changing an arc of an existing state + // of the wrapped fst via a MutableArcIterator, or adding a new state via + // AddState(). + void MutateCheck() { + if (!data_.unique()) { + data_ = + std::make_shared>(*data_); + } + } + + // The FST that this FST wraps. The purpose of this class is to enable + // non-destructive edits on this wrapped FST. + std::unique_ptr wrapped_; + // The mutable data for this EditFst instance, with delegates for all the + // methods that can mutate data. + std::shared_ptr> data_; +}; + +template +constexpr uint64 EditFstImpl::kStaticProperties; + +template +constexpr int EditFstImpl::kFileVersion; + +template +constexpr int EditFstImpl::kMinFileVersion; + +template +inline void EditFstImpl::DeleteStates() { + data_->DeleteStates(); + // we are deleting all states, so just forget about pointer to wrapped_ + // and do what default constructor does: set wrapped_ to a new VectorFst + wrapped_.reset(new MutableFstT()); + const auto new_props = + DeleteAllStatesProperties(FstImpl::Properties(), kStaticProperties); + FstImpl::SetProperties(new_props); +} + +template +EditFstImpl * +EditFstImpl::Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = new EditFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->SetStart(hdr.Start()); + // Reads in wrapped FST. + FstReadOptions wrapped_opts(opts); + // Contained header was written out, so reads it in too. + wrapped_opts.header = nullptr; + std::unique_ptr> wrapped_fst(Fst::Read(strm, wrapped_opts)); + if (!wrapped_fst) return nullptr; + impl->wrapped_.reset(static_cast(wrapped_fst.release())); + impl->data_ = std::shared_ptr>( + EditFstData::Read(strm, opts)); + if (!impl->data_) return nullptr; + return impl; +} + +} // namespace internal + +// Concrete, editable FST. This class attaches interface to implementation. +template , + typename MutableFstT = VectorFst> +class EditFst : public ImplToMutableFst< + internal::EditFstImpl> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::EditFstImpl; + + friend class MutableArcIterator>; + + EditFst() : ImplToMutableFst(std::make_shared()) {} + + explicit EditFst(const Fst &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + explicit EditFst(const WrappedFstT &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + // See Fst<>::Copy() for doc. + EditFst(const EditFst &fst, bool safe = false) + : ImplToMutableFst(fst, safe) {} + + ~EditFst() override {} + + // Gets a copy of this EditFst. See Fst<>::Copy() for further doc. + EditFst *Copy( + bool safe = false) const override { + return new EditFst(*this, safe); + } + + EditFst &operator=( + const EditFst &fst) { + SetImpl(fst.GetSharedImpl()); + return *this; + } + + EditFst &operator=( + const Fst &fst) override { + SetImpl(std::make_shared(fst)); + return *this; + } + + // Reads an EditFst from an input stream, returning nullptr on error. + static EditFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new EditFst(std::shared_ptr(impl)) : nullptr; + } + + // Reads an EditFst from a file, returning nullptr on error. If the filename + // argument is an empty string, it reads from standard input. + static EditFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst>::Read(filename); + return impl ? new EditFst( + std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + void InitMutableArcIterator(StateId s, + MutableArcIteratorData *data) override { + GetMutableImpl()->InitMutableArcIterator(s, data); + } + + private: + explicit EditFst(std::shared_ptr impl) : ImplToMutableFst(impl) {} + + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; + using ImplToFst>::SetImpl; +}; + +} // namespace fst + +#endif // FST_EDIT_FST_H_ diff --git a/projects/llm_framework/include/fst/encode.h b/projects/llm_framework/include/fst/encode.h new file mode 100644 index 00000000..f251bbfc --- /dev/null +++ b/projects/llm_framework/include/fst/encode.h @@ -0,0 +1,556 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to encode and decode an FST. + +#ifndef FST_ENCODE_H_ +#define FST_ENCODE_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + + +namespace fst { + +enum EncodeType { ENCODE = 1, DECODE = 2 }; + +static constexpr uint32 kEncodeLabels = 0x0001; +static constexpr uint32 kEncodeWeights = 0x0002; +static constexpr uint32 kEncodeFlags = 0x0003; + +namespace internal { + +static constexpr uint32 kEncodeHasISymbols = 0x0004; +static constexpr uint32 kEncodeHasOSymbols = 0x0008; + +// Identifies stream data as an encode table (and its endianity) +static const int32 kEncodeMagicNumber = 2129983209; + +// The following class encapsulates implementation details for the encoding and +// decoding of label/weight tuples used for encoding and decoding of FSTs. The +// EncodeTable is bidirectional. I.e, it stores both the Tuple of encode labels +// and weights to a unique label, and the reverse. +template +class EncodeTable { + public: + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + // Encoded data consists of arc input/output labels and arc weight. + struct Tuple { + Tuple() {} + + Tuple(Label ilabel_, Label olabel_, Weight weight_) + : ilabel(ilabel_), olabel(olabel_), weight(std::move(weight_)) {} + + Tuple(const Tuple &tuple) + : ilabel(tuple.ilabel), + olabel(tuple.olabel), + weight(std::move(tuple.weight)) {} + + Label ilabel; + Label olabel; + Weight weight; + }; + + // Comparison object for hashing EncodeTable Tuple(s). + class TupleEqual { + public: + bool operator()(const Tuple *x, const Tuple *y) const { + return (x->ilabel == y->ilabel && x->olabel == y->olabel && + x->weight == y->weight); + } + }; + + // Hash function for EncodeTabe Tuples. Based on the encode flags + // we either hash the labels, weights or combination of them. + class TupleKey { + public: + TupleKey() : encode_flags_(kEncodeLabels | kEncodeWeights) {} + + TupleKey(const TupleKey &key) : encode_flags_(key.encode_flags_) {} + + explicit TupleKey(uint32 encode_flags) : encode_flags_(encode_flags) {} + + size_t operator()(const Tuple *x) const { + size_t hash = x->ilabel; + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5; + if (encode_flags_ & kEncodeLabels) { + hash = hash << lshift ^ hash >> rshift ^ x->olabel; + } + if (encode_flags_ & kEncodeWeights) { + hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); + } + return hash; + } + + private: + int32 encode_flags_; + }; + + explicit EncodeTable(uint32 encode_flags) + : flags_(encode_flags), encode_hash_(1024, TupleKey(encode_flags)) {} + + using EncodeHash = std::unordered_map; + + // Given an arc, encodes either input/output labels or input/costs or both. + Label Encode(const Arc &arc) { + std::unique_ptr tuple( + new Tuple(arc.ilabel, flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One())); + auto insert_result = encode_hash_.insert( + std::make_pair(tuple.get(), encode_tuples_.size() + 1)); + if (insert_result.second) encode_tuples_.push_back(std::move(tuple)); + return insert_result.first->second; + } + + // Given an arc, looks up its encoded label or returns kNoLabel if not found. + Label GetLabel(const Arc &arc) const { + const Tuple tuple(arc.ilabel, flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One()); + auto it = encode_hash_.find(&tuple); + return (it == encode_hash_.end()) ? kNoLabel : it->second; + } + + // Given an encoded arc label, decodes back to input/output labels and costs. + const Tuple *Decode(Label key) const { + if (key < 1 || key > encode_tuples_.size()) { + LOG(ERROR) << "EncodeTable::Decode: Unknown decode key: " << key; + return nullptr; + } + return encode_tuples_[key - 1].get(); + } + + size_t Size() const { return encode_tuples_.size(); } + + bool Write(std::ostream &strm, const string &source) const; + + static EncodeTable *Read(std::istream &strm, const string &source); + + uint32 Flags() const { return flags_ & kEncodeFlags; } + + const SymbolTable *InputSymbols() const { return isymbols_.get(); } + + const SymbolTable *OutputSymbols() const { return osymbols_.get(); } + + void SetInputSymbols(const SymbolTable *syms) { + if (syms) { + isymbols_.reset(syms->Copy()); + flags_ |= kEncodeHasISymbols; + } else { + isymbols_.reset(); + flags_ &= ~kEncodeHasISymbols; + } + } + + void SetOutputSymbols(const SymbolTable *syms) { + if (syms) { + osymbols_.reset(syms->Copy()); + flags_ |= kEncodeHasOSymbols; + } else { + osymbols_.reset(); + flags_ &= ~kEncodeHasOSymbols; + } + } + + private: + uint32 flags_; + std::vector> encode_tuples_; + EncodeHash encode_hash_; + std::unique_ptr isymbols_; // Pre-encoded input symbol table. + std::unique_ptr osymbols_; // Pre-encoded output symbol table. + + EncodeTable(const EncodeTable &) = delete; + EncodeTable &operator=(const EncodeTable &) = delete; +}; + +template +bool EncodeTable::Write(std::ostream &strm, + const string &source) const { + WriteType(strm, kEncodeMagicNumber); + WriteType(strm, flags_); + const int64 size = encode_tuples_.size(); + WriteType(strm, size); + for (const auto &tuple : encode_tuples_) { + WriteType(strm, tuple->ilabel); + WriteType(strm, tuple->olabel); + tuple->weight.Write(strm); + } + if (flags_ & kEncodeHasISymbols) isymbols_->Write(strm); + if (flags_ & kEncodeHasOSymbols) osymbols_->Write(strm); + strm.flush(); + if (!strm) { + LOG(ERROR) << "EncodeTable::Write: Write failed: " << source; + return false; + } + return true; +} + +template +EncodeTable *EncodeTable::Read(std::istream &strm, + const string &source) { + int32 magic_number = 0; + ReadType(strm, &magic_number); + if (magic_number != kEncodeMagicNumber) { + LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; + return nullptr; + } + uint32 flags; + ReadType(strm, &flags); + int64 size; + ReadType(strm, &size); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: Read failed: " << source; + return nullptr; + } + std::unique_ptr> table(new EncodeTable(flags)); + for (int64 i = 0; i < size; ++i) { + std::unique_ptr tuple(new Tuple()); + ReadType(strm, &tuple->ilabel); + ReadType(strm, &tuple->olabel); + tuple->weight.Read(strm); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: Read failed: " << source; + return nullptr; + } + table->encode_tuples_.push_back(std::move(tuple)); + table->encode_hash_[table->encode_tuples_.back().get()] = + table->encode_tuples_.size(); + } + if (flags & kEncodeHasISymbols) { + table->isymbols_.reset(SymbolTable::Read(strm, source)); + } + if (flags & kEncodeHasOSymbols) { + table->osymbols_.reset(SymbolTable::Read(strm, source)); + } + return table.release(); +} + +} // namespace internal + +// A mapper to encode/decode weighted transducers. Encoding of an FST is used +// for performing classical determinization or minimization on a weighted +// transducer viewing it as an unweighted acceptor over encoded labels. +// +// The mapper stores the encoding in a local hash table (EncodeTable). This +// table is shared (and reference-counted) between the encoder and decoder. +// A decoder has read-only access to the EncodeTable. +// +// The EncodeMapper allows on the fly encoding of the machine. As the +// EncodeTable is generated the same table may by used to decode the machine +// on the fly. For example in the following sequence of operations +// +// Encode -> Determinize -> Decode +// +// we will use the encoding table generated during the encode step in the +// decode, even though the encoding is not complete. +template +class EncodeMapper { + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + public: + EncodeMapper(uint32 flags, EncodeType type) + : flags_(flags), + type_(type), + table_(std::make_shared>(flags)), + error_(false) {} + + EncodeMapper(const EncodeMapper &mapper) + : flags_(mapper.flags_), + type_(mapper.type_), + table_(mapper.table_), + error_(false) {} + + // Copy constructor but setting the type, typically to DECODE. + EncodeMapper(const EncodeMapper &mapper, EncodeType type) + : flags_(mapper.flags_), + type_(type), + table_(mapper.table_), + error_(mapper.error_) {} + + Arc operator()(const Arc &arc); + + MapFinalAction FinalAction() const { + return (type_ == ENCODE && (flags_ & kEncodeWeights)) + ? MAP_REQUIRE_SUPERFINAL + : MAP_NO_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 inprops) { + uint64 outprops = inprops; + if (error_) outprops |= kError; + uint64 mask = kFstProperties; + if (flags_ & kEncodeLabels) { + mask &= kILabelInvariantProperties & kOLabelInvariantProperties; + } + if (flags_ & kEncodeWeights) { + mask &= kILabelInvariantProperties & kWeightInvariantProperties & + (type_ == ENCODE ? kAddSuperFinalProperties + : kRmSuperFinalProperties); + } + return outprops & mask; + } + + uint32 Flags() const { return flags_; } + + EncodeType Type() const { return type_; } + + bool Write(std::ostream &strm, const string &source) const { + return table_->Write(strm, source); + } + + bool Write(const string &filename) const { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return false; + } + return Write(strm, filename); + } + + static EncodeMapper *Read(std::istream &strm, const string &source, + EncodeType type = ENCODE) { + auto *table = internal::EncodeTable::Read(strm, source); + return table ? new EncodeMapper(table->Flags(), type, table) : nullptr; + } + + static EncodeMapper *Read(const string &filename, + EncodeType type = ENCODE) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return nullptr; + } + return Read(strm, filename, type); + } + + const SymbolTable *InputSymbols() const { return table_->InputSymbols(); } + + const SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); } + + void SetInputSymbols(const SymbolTable *syms) { + table_->SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable *syms) { + table_->SetOutputSymbols(syms); + } + + private: + uint32 flags_; + EncodeType type_; + std::shared_ptr> table_; + bool error_; + + explicit EncodeMapper(uint32 flags, EncodeType type, + internal::EncodeTable *table) + : flags_(flags), type_(type), table_(table), error_(false) {} + + EncodeMapper &operator=(const EncodeMapper &) = delete; +}; + +template +Arc EncodeMapper::operator()(const Arc &arc) { + if (type_ == ENCODE) { + if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || + (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && + arc.weight == Weight::Zero())) { + return arc; + } else { + const auto label = table_->Encode(arc); + return Arc(label, flags_ & kEncodeLabels ? label : arc.olabel, + flags_ & kEncodeWeights ? Weight::One() : arc.weight, + arc.nextstate); + } + } else { // type_ == DECODE + if (arc.nextstate == kNoStateId) { + return arc; + } else { + if (arc.ilabel == 0) return arc; + if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) { + FSTERROR() << "EncodeMapper: Label-encoded arc has different " + "input and output labels"; + error_ = true; + } + if (flags_ & kEncodeWeights && arc.weight != Weight::One()) { + FSTERROR() << "EncodeMapper: Weight-encoded arc has non-trivial weight"; + error_ = true; + } + const auto tuple = table_->Decode(arc.ilabel); + if (!tuple) { + FSTERROR() << "EncodeMapper: Decode failed"; + error_ = true; + return Arc(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate); + } else { + return Arc(tuple->ilabel, + flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, + flags_ & kEncodeWeights ? tuple->weight : arc.weight, + arc.nextstate); + } + } + } +} + +// Complexity: O(E + V). +template +inline void Encode(MutableFst *fst, EncodeMapper *mapper) { + mapper->SetInputSymbols(fst->InputSymbols()); + mapper->SetOutputSymbols(fst->OutputSymbols()); + ArcMap(fst, mapper); +} + +template +inline void Decode(MutableFst *fst, const EncodeMapper &mapper) { + ArcMap(fst, EncodeMapper(mapper, DECODE)); + RmFinalEpsilon(fst); + fst->SetInputSymbols(mapper.InputSymbols()); + fst->SetOutputSymbols(mapper.OutputSymbols()); +} + +// On-the-fly encoding of an input FST. +// +// Complexity: +// +// Construction: O(1) +// Traversal: O(e + v) +// +// where e is the number of arcs visited and v is the number of states visited. +// Constant time and space to visit an input state or arc is assumed and +// exclusive of caching. +template +class EncodeFst : public ArcMapFst> { + public: + using Mapper = EncodeMapper; + using Impl = internal::ArcMapFstImpl; + + EncodeFst(const Fst &fst, Mapper *encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) { + encoder->SetInputSymbols(fst.InputSymbols()); + encoder->SetOutputSymbols(fst.OutputSymbols()); + } + + EncodeFst(const Fst &fst, const Mapper &encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) {} + + // See Fst<>::Copy() for doc. + EncodeFst(const EncodeFst &fst, bool copy = false) + : ArcMapFst(fst, copy) {} + + // Makes a copy of this EncodeFst. See Fst<>::Copy() for further doc. + EncodeFst *Copy(bool safe = false) const override { + if (safe) { + FSTERROR() << "EncodeFst::Copy(true): Not allowed"; + GetImpl()->SetProperties(kError, kError); + } + return new EncodeFst(*this); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; +}; + +// On-the-fly decoding of an input FST. +// +// Complexity: +// +// Construction: O(1). +// Traversal: O(e + v) +// +// Constant time and space to visit an input state or arc is assumed and +// exclusive of caching. +template +class DecodeFst : public ArcMapFst> { + public: + using Mapper = EncodeMapper; + using Impl = internal::ArcMapFstImpl; + using ImplToFst::GetImpl; + + DecodeFst(const Fst &fst, const Mapper &encoder) + : ArcMapFst(fst, Mapper(encoder, DECODE), + ArcMapFstOptions()) { + GetMutableImpl()->SetInputSymbols(encoder.InputSymbols()); + GetMutableImpl()->SetOutputSymbols(encoder.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + DecodeFst(const DecodeFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Makes a copy of this DecodeFst. See Fst<>::Copy() for further doc. + DecodeFst *Copy(bool safe = false) const override { + return new DecodeFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for EncodeFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const EncodeFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for EncodeFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const EncodeFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Specialization for DecodeFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const DecodeFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for DecodeFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const DecodeFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful aliases when using StdArc. + +using StdEncodeFst = EncodeFst; + +using StdDecodeFst = DecodeFst; + +} // namespace fst + +#endif // FST_ENCODE_H_ diff --git a/projects/llm_framework/include/fst/epsnormalize.h b/projects/llm_framework/include/fst/epsnormalize.h new file mode 100644 index 00000000..18105fb1 --- /dev/null +++ b/projects/llm_framework/include/fst/epsnormalize.h @@ -0,0 +1,61 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function that implements epsilon-normalization. + +#ifndef FST_EPSNORMALIZE_H_ +#define FST_EPSNORMALIZE_H_ + + +#include +#include +#include +#include + + +namespace fst { + +enum EpsNormalizeType { EPS_NORM_INPUT, EPS_NORM_OUTPUT }; + +// Returns an equivalent FST that is epsilon-normalized. An acceptor is +// epsilon-normalized if it is epsilon-removed. A transducer is input +// epsilon-normalized if additionally if on each path any epsilon input +// label follows all non-epsilon input labels. Output epsilon-normalized +// is defined similarly. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. International Journal of Computer +// Science, 13(1): 129-143, 2002. +template +void EpsNormalize(const Fst &ifst, MutableFst *ofst, + EpsNormalizeType type = EPS_NORM_INPUT) { + EpsNormalize(ifst, ofst, type); +} + +// Same as above, except allows specifying explicitly the gallic weight type. +template +void EpsNormalize(const Fst &ifst, MutableFst *ofst, + EpsNormalizeType type) { + VectorFst> gfst; + std::unique_ptr symbols; + if (type == EPS_NORM_INPUT) { + ArcMap(ifst, &gfst, ToGallicMapper()); + if (ifst.OutputSymbols()) symbols.reset(ifst.OutputSymbols()->Copy()); + } else { // type == EPS_NORM_OUTPUT + ArcMap(InvertFst(ifst), &gfst, ToGallicMapper()); + if (ifst.InputSymbols()) symbols.reset(ifst.InputSymbols()->Copy()); + } + RmEpsilon(&gfst); + FactorWeightFst, + GallicFactor> + fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper()); + ofst->SetOutputSymbols(symbols.get()); + if (type == EPS_NORM_OUTPUT) Invert(ofst); +} + +} // namespace fst + +#endif // FST_EPSNORMALIZE_H_ diff --git a/projects/llm_framework/include/fst/equal.h b/projects/llm_framework/include/fst/equal.h new file mode 100644 index 00000000..ed89c6ce --- /dev/null +++ b/projects/llm_framework/include/fst/equal.h @@ -0,0 +1,169 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to test equality of two FSTs. + +#ifndef FST_EQUAL_H_ +#define FST_EQUAL_H_ + +#include + +#include +#include + + +namespace fst { + +constexpr uint32 kEqualFsts = 0x0001; +constexpr uint32 kEqualFstTypes = 0x0002; +constexpr uint32 kEqualCompatProperties = 0x0004; +constexpr uint32 kEqualCompatSymbols = 0x0008; +constexpr uint32 kEqualAll = + kEqualFsts | kEqualFstTypes | kEqualCompatProperties | kEqualCompatSymbols; + +class WeightApproxEqual { + public: + explicit WeightApproxEqual(float delta) : delta_(delta) {} + + template + bool operator()(const Weight &w1, const Weight &w2) const { + return ApproxEqual(w1, w2, delta_); + } + + private: + float delta_; +}; + +// Tests if two Fsts have the same states and arcs in the same order (when +// etype & kEqualFst). +// Also optional checks equality of Fst types (etype & kEqualFstTypes) and +// compatibility of stored properties (etype & kEqualCompatProperties) and +// of symbol tables (etype & kEqualCompatSymbols). +template +bool Equal(const Fst &fst1, const Fst &fst2, + WeightEqual weight_equal, uint32 etype = kEqualFsts) { + if ((etype & kEqualFstTypes) && (fst1.Type() != fst2.Type())) { + VLOG(1) << "Equal: Mismatched FST types (" << fst1.Type() << " != " + << fst2.Type() << ")"; + return false; + } + if ((etype & kEqualCompatProperties) && + !CompatProperties(fst1.Properties(kCopyProperties, false), + fst2.Properties(kCopyProperties, false))) { + VLOG(1) << "Equal: Properties not compatible"; + return false; + } + if (etype & kEqualCompatSymbols) { + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols(), false)) { + VLOG(1) << "Equal: Input symbols not compatible"; + return false; + } + if (!CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols(), false)) { + VLOG(1) << "Equal: Output symbols not compatible"; + return false; + } + } + if (!(etype & kEqualFsts)) return true; + if (fst1.Start() != fst2.Start()) { + VLOG(1) << "Equal: Mismatched start states (" << fst1.Start() << " != " + << fst2.Start() << ")"; + return false; + } + StateIterator> siter1(fst1); + StateIterator> siter2(fst2); + while (!siter1.Done() || !siter2.Done()) { + if (siter1.Done() || siter2.Done()) { + VLOG(1) << "Equal: Mismatched number of states"; + return false; + } + const auto s1 = siter1.Value(); + const auto s2 = siter2.Value(); + if (s1 != s2) { + VLOG(1) << "Equal: Mismatched states (" << s1 << "!= " + << s2 << ")"; + return false; + } + const auto &final1 = fst1.Final(s1); + const auto &final2 = fst2.Final(s2); + if (!weight_equal(final1, final2)) { + VLOG(1) << "Equal: Mismatched final weights at state " << s1 + << " (" << final1 << " != " << final2 << ")"; + return false; + } + ArcIterator> aiter1(fst1, s1); + ArcIterator> aiter2(fst2, s2); + for (auto a = 0; !aiter1.Done() || !aiter2.Done(); ++a) { + if (aiter1.Done() || aiter2.Done()) { + VLOG(1) << "Equal: Mismatched number of arcs at state " << s1; + return false; + } + const auto &arc1 = aiter1.Value(); + const auto &arc2 = aiter2.Value(); + if (arc1.ilabel != arc2.ilabel) { + VLOG(1) << "Equal: Mismatched arc input labels at state " << s1 + << ", arc " << a << " (" << arc1.ilabel << " != " + << arc2.ilabel << ")"; + return false; + } else if (arc1.olabel != arc2.olabel) { + VLOG(1) << "Equal: Mismatched arc output labels at state " << s1 + << ", arc " << a << " (" << arc1.olabel << " != " + << arc2.olabel << ")"; + return false; + } else if (!weight_equal(arc1.weight, arc2.weight)) { + VLOG(1) << "Equal: Mismatched arc weights at state " << s1 + << ", arc " << a << " (" << arc1.weight << " != " + << arc2.weight << ")"; + return false; + } else if (arc1.nextstate != arc2.nextstate) { + VLOG(1) << "Equal: Mismatched next state at state " << s1 + << ", arc " << a << " (" << arc1.nextstate << " != " + << arc2.nextstate << ")"; + return false; + } + aiter1.Next(); + aiter2.Next(); + } + // Sanity checks: should never fail. + if (fst1.NumArcs(s1) != fst2.NumArcs(s2)) { + FSTERROR() << "Equal: Inconsistent arc counts at state " << s1 + << " (" << fst1.NumArcs(s1) << " != " + << fst2.NumArcs(s2) << ")"; + return false; + } + if (fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2)) { + FSTERROR() << "Equal: Inconsistent input epsilon counts at state " << s1 + << " (" << fst1.NumInputEpsilons(s1) << " != " + << fst2.NumInputEpsilons(s2) << ")"; + return false; + } + if (fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) { + FSTERROR() << "Equal: Inconsistent output epsilon counts at state " << s1 + << " (" << fst1.NumOutputEpsilons(s1) << " != " + << fst2.NumOutputEpsilons(s2) << ")"; + } + siter1.Next(); + siter2.Next(); + } + return true; +} + +template +bool Equal(const Fst &fst1, const Fst &fst2, + float delta = kDelta, uint32 etype = kEqualFsts) { + return Equal(fst1, fst2, WeightApproxEqual(delta), etype); +} + +// Support double deltas without forcing all clients to cast to float. +// Without this overload, Equal will be chosen, +// since it is a better match than double -> float narrowing, but +// the instantiation will fail. +template +bool Equal(const Fst &fst1, const Fst &fst2, + double delta, uint32 etype = kEqualFsts) { + return Equal(fst1, fst2, WeightApproxEqual(static_cast(delta)), etype); +} + + +} // namespace fst + +#endif // FST_EQUAL_H_ diff --git a/projects/llm_framework/include/fst/equivalent.h b/projects/llm_framework/include/fst/equivalent.h new file mode 100644 index 00000000..cf3fdb61 --- /dev/null +++ b/projects/llm_framework/include/fst/equivalent.h @@ -0,0 +1,230 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to determine the equivalence of two FSTs. + +#ifndef FST_EQUIVALENT_H_ +#define FST_EQUIVALENT_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + + +namespace fst { +namespace internal { + +// Traits-like struct holding utility functions/typedefs/constants for +// the equivalence algorithm. +// +// Encoding device: in order to make the statesets of the two acceptors +// disjoint, we map Arc::StateId on the type MappedId. The states of +// the first acceptor are mapped on odd numbers (s -> 2s + 1), and +// those of the second one on even numbers (s -> 2s + 2). The number 0 +// is reserved for an implicit (non-final) dead state (required for +// the correct treatment of non-coaccessible states; kNoStateId is mapped to +// kDeadState for both acceptors). The union-find algorithm operates on the +// mapped IDs. +template +struct EquivalenceUtil { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MappedId = StateId; // ID for an equivalence class. + + // MappedId for an implicit dead state. + static constexpr MappedId kDeadState = 0; + + // MappedId for lookup failure. + static constexpr MappedId kInvalidId = -1; + + // Maps state ID to the representative of the corresponding + // equivalence class. The parameter 'which_fst' takes the values 1 + // and 2, identifying the input FST. + static MappedId MapState(StateId s, int32 which_fst) { + return (kNoStateId == s) ? kDeadState + : (static_cast(s) << 1) + which_fst; + } + + // Maps set ID to State ID. + static StateId UnMapState(MappedId id) { + return static_cast((--id) >> 1); + } + + // Convenience function: checks if state with MappedId s is final in + // acceptor fa. + static bool IsFinal(const Fst &fa, MappedId s) { + return (kDeadState == s) ? false + : (fa.Final(UnMapState(s)) != Weight::Zero()); + } + // Convenience function: returns the representative of ID in sets, + // creating a new set if needed. + static MappedId FindSet(UnionFind *sets, MappedId id) { + const auto repr = sets->FindSet(id); + if (repr != kInvalidId) { + return repr; + } else { + sets->MakeSet(id); + return id; + } + } +}; + +template +constexpr + typename EquivalenceUtil::MappedId EquivalenceUtil::kDeadState; + +template +constexpr + typename EquivalenceUtil::MappedId EquivalenceUtil::kInvalidId; + +} // namespace internal + +// Equivalence checking algorithm: determines if the two FSTs fst1 and fst2 +// are equivalent. The input FSTs must be deterministic input-side epsilon-free +// acceptors, unweighted or with weights over a left semiring. Two acceptors are +// considered equivalent if they accept exactly the same set of strings (with +// the same weights). +// +// The algorithm (cf. Aho, Hopcroft and Ullman, "The Design and Analysis of +// Computer Programs") successively constructs sets of states that can be +// reached by the same prefixes, starting with a set containing the start states +// of both acceptors. A disjoint tree forest (the union-find algorithm) is used +// to represent the sets of states. The algorithm returns false if one of the +// constructed sets contains both final and non-final states. Returns an +// optional error value (useful when FLAGS_error_fatal = false). +// +// Complexity: +// +// Quasi-linear, i.e., O(n G(n)), where +// +// n = |S1| + |S2| is the number of states in both acceptors +// +// G(n) is a very slowly growing function that can be approximated +// by 4 by all practical purposes. +template +bool Equivalent(const Fst &fst1, const Fst &fst2, + float delta = kDelta, bool *error = nullptr) { + using Weight = typename Arc::Weight; + if (error) *error = false; + // Check that the symbol table are compatible. + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Equivalent: Input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + if (error) *error = true; + return false; + } + // Check properties first. + static constexpr auto props = kNoEpsilons | kIDeterministic | kAcceptor; + if (fst1.Properties(props, true) != props) { + FSTERROR() << "Equivalent: 1st argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + if (fst2.Properties(props, true) != props) { + FSTERROR() << "Equivalent: 2nd argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + if ((fst1.Properties(kUnweighted, true) != kUnweighted) || + (fst2.Properties(kUnweighted, true) != kUnweighted)) { + VectorFst efst1(fst1); + VectorFst efst2(fst2); + Push(&efst1, REWEIGHT_TO_INITIAL, delta); + Push(&efst2, REWEIGHT_TO_INITIAL, delta); + ArcMap(&efst1, QuantizeMapper(delta)); + ArcMap(&efst2, QuantizeMapper(delta)); + EncodeMapper mapper(kEncodeWeights | kEncodeLabels, ENCODE); + ArcMap(&efst1, &mapper); + ArcMap(&efst2, &mapper); + return Equivalent(efst1, efst2); + } + using Util = internal::EquivalenceUtil; + using MappedId = typename Util::MappedId; + enum { FST1 = 1, FST2 = 2 }; // Required by Util::MapState(...) + auto s1 = Util::MapState(fst1.Start(), FST1); + auto s2 = Util::MapState(fst2.Start(), FST2); + // The union-find structure. + UnionFind eq_classes(1000, Util::kInvalidId); + // Initializes the union-find structure. + eq_classes.MakeSet(s1); + eq_classes.MakeSet(s2); + // Data structure for the (partial) acceptor transition function of fst1 and + // fst2: input labels mapped to pairs of MappedIds representing destination + // states of the corresponding arcs in fst1 and fst2, respectively. + using Label2StatePairMap = + std::unordered_map>; + Label2StatePairMap arc_pairs; + // Pairs of MappedId's to be processed, organized in a queue. + std::deque> q; + bool ret = true; + // Returns early if the start states differ w.r.t. finality. + if (Util::IsFinal(fst1, s1) != Util::IsFinal(fst2, s2)) ret = false; + // Main loop: explores the two acceptors in a breadth-first manner, updating + // the equivalence relation on the statesets. Loop invariant: each block of + // the states contains either final states only or non-final states only. + for (q.push_back(std::make_pair(s1, s2)); ret && !q.empty(); q.pop_front()) { + s1 = q.front().first; + s2 = q.front().second; + // Representatives of the equivalence classes of s1/s2. + const auto rep1 = Util::FindSet(&eq_classes, s1); + const auto rep2 = Util::FindSet(&eq_classes, s2); + if (rep1 != rep2) { + eq_classes.Union(rep1, rep2); + arc_pairs.clear(); + // Copies outgoing arcs starting at s1 into the hash-table. + if (Util::kDeadState != s1) { + ArcIterator> arc_iter(fst1, Util::UnMapState(s1)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const auto &arc = arc_iter.Value(); + // Zero-weight arcs are treated as if they did not exist. + if (arc.weight != Weight::Zero()) { + arc_pairs[arc.ilabel].first = Util::MapState(arc.nextstate, FST1); + } + } + } + // Copies outgoing arcs starting at s2 into the hashtable. + if (Util::kDeadState != s2) { + ArcIterator> arc_iter(fst2, Util::UnMapState(s2)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const auto &arc = arc_iter.Value(); + // Zero-weight arcs are treated as if they did not exist. + if (arc.weight != Weight::Zero()) { + arc_pairs[arc.ilabel].second = Util::MapState(arc.nextstate, FST2); + } + } + } + // Iterates through the hashtable and process pairs of target states. + for (const auto &arc_iter : arc_pairs) { + const auto &pair = arc_iter.second; + if (Util::IsFinal(fst1, pair.first) != + Util::IsFinal(fst2, pair.second)) { + // Detected inconsistency: return false. + ret = false; + break; + } + q.push_back(pair); + } + } + } + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + return ret; +} + +} // namespace fst + +#endif // FST_EQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/expanded-fst.h b/projects/llm_framework/include/fst/expanded-fst.h new file mode 100644 index 00000000..2c7d514c --- /dev/null +++ b/projects/llm_framework/include/fst/expanded-fst.h @@ -0,0 +1,179 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Generic FST augmented with state count-interface class definition. + +#ifndef FST_EXPANDED_FST_H_ +#define FST_EXPANDED_FST_H_ + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +// A generic FST plus state count. +template +class ExpandedFst : public Fst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + virtual StateId NumStates() const = 0; // State count + + // Get a copy of this ExpandedFst. See Fst<>::Copy() for further doc. + ExpandedFst *Copy(bool safe = false) const override = 0; + + // Read an ExpandedFst from an input stream; return NULL on error. + static ExpandedFst *Read(std::istream &strm, + const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) { + hdr = *opts.header; + } else { + if (!hdr.Read(strm, opts.source)) return nullptr; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kExpanded)) { + LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: " << ropts.source; + return nullptr; + } + const auto reader = + FstRegister::GetRegister()->GetReader(hdr.FstType()); + if (!reader) { + LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \"" << hdr.FstType() + << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source; + return nullptr; + } + auto *fst = reader(strm, ropts); + if (!fst) return nullptr; + return static_cast *>(fst); + } + + // Read an ExpandedFst from a file; return NULL on error. + // Empty filename reads from standard input. + static ExpandedFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } +}; + +namespace internal { + +// ExpandedFst case - abstract methods. +template +inline typename Arc::Weight Final(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.Final(s); +} + +template +inline ssize_t NumArcs(const ExpandedFst &fst, typename Arc::StateId s) { + return fst.NumArcs(s); +} + +template +inline ssize_t NumInputEpsilons(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.NumInputEpsilons(s); +} + +template +inline ssize_t NumOutputEpsilons(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + +// A useful alias when using StdArc. +using StdExpandedFst = ExpandedFst; + +// This is a helper class template useful for attaching an ExpandedFst +// interface to its implementation, handling reference counting. It +// delegates to ImplToFst the handling of the Fst interface methods. +template > +class ImplToExpandedFst : public ImplToFst { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + StateId NumStates() const override { return GetImpl()->NumStates(); } + + protected: + using ImplToFst::GetImpl; + + explicit ImplToExpandedFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + ImplToExpandedFst(const ImplToExpandedFst &fst, bool safe) + : ImplToFst(fst, safe) {} + + static Impl *Read(std::istream &strm, const FstReadOptions &opts) { + return Impl::Read(strm, opts); + } + + // Read FST implementation from a file; return NULL on error. + // Empty filename reads from standard input. + static Impl *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return nullptr; + } + return Impl::Read(strm, FstReadOptions(filename)); + } else { + return Impl::Read(std::cin, FstReadOptions("standard input")); + } + } +}; + +// Function to return the number of states in an FST, counting them +// if necessary. +template +typename Arc::StateId CountStates(const Fst &fst) { + if (fst.Properties(kExpanded, false)) { + const auto *efst = static_cast *>(&fst); + return efst->NumStates(); + } else { + typename Arc::StateId nstates = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates; + } + return nstates; + } +} + +// Function to return the number of arcs in an FST. +template +typename Arc::StateId CountArcs(const Fst &fst) { + size_t narcs = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + narcs += fst.NumArcs(siter.Value()); + } + return narcs; +} + +} // namespace fst + +#endif // FST_EXPANDED_FST_H_ diff --git a/projects/llm_framework/include/fst/expectation-weight.h b/projects/llm_framework/include/fst/expectation-weight.h new file mode 100644 index 00000000..f996cbc6 --- /dev/null +++ b/projects/llm_framework/include/fst/expectation-weight.h @@ -0,0 +1,134 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expectation semiring as described by Jason Eisner: +// See: doi=10.1.1.22.9398 +// Multiplex semiring operations and identities: +// One: +// Zero: +// Plus: + = < (a1 + a2) , (b1 + b2) > +// Times: * = < (a1 * a2) , [(a1 * b2) + (a2 * b1)] > +// Division: Undefined (currently) +// +// Usually used to store the pair so that +// ShortestDistance[Fst>>] +// == < PosteriorProbability, Expected_Value[V] > + +#ifndef FST_EXPECTATION_WEIGHT_H_ +#define FST_EXPECTATION_WEIGHT_H_ + +#include + +#include + +#include +#include + + +namespace fst { + +// X1 is usually a probability weight like LogWeight. +// X2 is usually a random variable or vector (see SignedLogWeight or +// SparsePowerWeight). +// +// If X1 is distinct from X2, it is required that there is an external product +// between X1 and X2 and if both semriring are commutative, or left or right +// semirings, then result must have those properties. +template +class ExpectationWeight : public PairWeight { + public: + using PairWeight::Value1; + using PairWeight::Value2; + + using PairWeight::Reverse; + using PairWeight::Quantize; + using PairWeight::Member; + + using ReverseWeight = + ExpectationWeight; + + ExpectationWeight() : PairWeight(Zero()) {} + + explicit ExpectationWeight(const PairWeight &weight) + : PairWeight(weight) {} + + ExpectationWeight(const X1 &x1, const X2 &x2) : PairWeight(x1, x2) {} + + static const ExpectationWeight &Zero() { + static const ExpectationWeight zero(X1::Zero(), X2::Zero()); + return zero; + } + + static const ExpectationWeight &One() { + static const ExpectationWeight one(X1::One(), X2::Zero()); + return one; + } + + static const ExpectationWeight &NoWeight() { + static const ExpectationWeight no_weight(X1::NoWeight(), X2::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string("expectation_" + X1::Type() + "_" + X2::Type()); + return *type; + } + + PairWeight Quantize(float delta = kDelta) const { + return ExpectationWeight(PairWeight::Quantize()); + } + + ReverseWeight Reverse() const { + return ReverseWeight(PairWeight::Reverse()); + } + + bool Member() const { return PairWeight::Member(); } + + static constexpr uint64 Properties() { + return X1::Properties() & X2::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } +}; + +template +inline ExpectationWeight Plus(const ExpectationWeight &w1, + const ExpectationWeight &w2) { + return ExpectationWeight(Plus(w1.Value1(), w2.Value1()), + Plus(w1.Value2(), w2.Value2())); +} + +template +inline ExpectationWeight Times(const ExpectationWeight &w1, + const ExpectationWeight &w2) { + return ExpectationWeight( + Times(w1.Value1(), w2.Value1()), + Plus(Times(w1.Value1(), w2.Value2()), Times(w1.Value2(), w2.Value1()))); +} + +template +inline ExpectationWeight Divide(const ExpectationWeight &w1, + const ExpectationWeight &w2, + DivideType typ = DIVIDE_ANY) { + FSTERROR() << "ExpectationWeight::Divide: Not implemented"; + return ExpectationWeight::NoWeight(); +} + +// This function object generates weights by calling the underlying generators +// for the template weight types, like all other pair weight types. This is +// intended primarily for testing. +template +class WeightGenerate> + : public WeightGenerate> { + public: + using Weight = ExpectationWeight; + using Generate = WeightGenerate>; + + explicit WeightGenerate(bool allow_zero = true) : Generate(allow_zero) {} + + Weight operator()() const { return Weight(Generate::operator()()); } +}; + +} // namespace fst + +#endif // FST_EXPECTATION_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/extensions/compress/compress-script.h b/projects/llm_framework/include/fst/extensions/compress/compress-script.h new file mode 100644 index 00000000..bad238aa --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/compress/compress-script.h @@ -0,0 +1,53 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Declarations of 'scriptable' versions of compression operations, that is, +// those that can be called with FstClass-type arguments. + +#ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ +#define FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +typedef std::tuple CompressArgs; + +template +void Compress(CompressArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + const string &filename = std::get<1>(*args); + const bool gzip = std::get<2>(*args); + + if (!fst::Compress(fst, filename, gzip)) FSTERROR() << "Compress: failed"; +} + +void Compress(const FstClass &fst, const string &filename, const bool gzip); + +typedef std::tuple + DecompressArgs; + +template +void Decompress(DecompressArgs *args) { + const string &filename = std::get<0>(*args); + MutableFst *fst = std::get<1>(*args)->GetMutableFst(); + const bool gzip = std::get<2>(*args); + + if (!fst::Decompress(filename, fst, gzip)) + FSTERROR() << "Decompress: failed"; +} + +void Decompress(const string &filename, MutableFstClass *fst, const bool gzip); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/compress/compress.h b/projects/llm_framework/include/fst/extensions/compress/compress.h new file mode 100644 index 00000000..aa94848f --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/compress/compress.h @@ -0,0 +1,906 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Compresses and decompresses unweighted FSTs. + +#ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ +#define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Identifies stream data as a vanilla compressed FST. +static const int32 kCompressMagicNumber = 1858869554; +// Identifies stream data as (probably) a Gzip file accidentally read from +// a vanilla stream, without gzip support. +static const int32 kGzipMagicNumber = 0x8b1f; +// Selects the two most significant bytes. +constexpr uint32 kGzipMask = 0xffffffff >> 16; + +namespace internal { + +// Expands a Lempel Ziv code and returns the set of code words. expanded_code[i] +// is the i^th Lempel Ziv codeword. +template +bool ExpandLZCode(const std::vector> &code, + std::vector> *expanded_code) { + expanded_code->resize(code.size()); + for (int i = 0; i < code.size(); ++i) { + if (code[i].first > i) { + LOG(ERROR) << "ExpandLZCode: Not a valid code"; + return false; + } + if (code[i].first == 0) { + (*expanded_code)[i].resize(1, code[i].second); + } else { + (*expanded_code)[i].resize((*expanded_code)[code[i].first - 1].size() + + 1); + std::copy((*expanded_code)[code[i].first - 1].begin(), + (*expanded_code)[code[i].first - 1].end(), + (*expanded_code)[i].begin()); + (*expanded_code)[i][(*expanded_code)[code[i].first - 1].size()] = + code[i].second; + } + } + return true; +} + +} // namespace internal + +// Lempel Ziv on data structure Edge, with a less than operator +// EdgeLessThan and an equals operator EdgeEquals. +// Edge has a value defaultedge which it never takes and +// Edge is defined, it is initialized to defaultedge +template +class LempelZiv { + public: + LempelZiv() : dict_number_(0), default_edge_() { + root_.current_number = dict_number_++; + root_.current_edge = default_edge_; + decode_vector_.push_back(std::make_pair(0, default_edge_)); + } + // Encodes a vector input into output + void BatchEncode(const std::vector &input, + std::vector> *output); + + // Decodes codedvector to output. Returns false if + // the index exceeds the size. + bool BatchDecode(const std::vector> &input, + std::vector *output); + + // Decodes a single dictionary element. Returns false + // if the index exceeds the size. + bool SingleDecode(const Var &index, Edge *output) { + if (index >= decode_vector_.size()) { + LOG(ERROR) << "LempelZiv::SingleDecode: " + << "Index exceeded the dictionary size"; + return false; + } else { + *output = decode_vector_[index].second; + return true; + } + } + + ~LempelZiv() { + for (auto it = (root_.next_number).begin(); it != (root_.next_number).end(); + ++it) { + CleanUp(it->second); + } + } + // Adds a single dictionary element while decoding + // void AddDictElement(const std::pair &newdict) { + // EdgeEquals InstEdgeEquals; + // if (InstEdgeEquals(newdict.second, default_edge_) != 1) + // decode_vector_.push_back(newdict); + // } + + private: + // Node datastructure is used for encoding + + struct Node { + Var current_number; + Edge current_edge; + std::map next_number; + }; + + void CleanUp(Node *temp) { + for (auto it = (temp->next_number).begin(); it != (temp->next_number).end(); + ++it) { + CleanUp(it->second); + } + delete temp; + } + Node root_; + Var dict_number_; + // decode_vector_ is used for decoding + std::vector> decode_vector_; + Edge default_edge_; +}; + +template +void LempelZiv::BatchEncode( + const std::vector &input, std::vector> *output) { + for (typename std::vector::const_iterator it = input.begin(); + it != input.end(); ++it) { + Node *temp_node = &root_; + while (it != input.end()) { + auto next = (temp_node->next_number).find(*it); + if (next != (temp_node->next_number).end()) { + temp_node = next->second; + ++it; + } else { + break; + } + } + if (it == input.end() && temp_node->current_number != 0) { + output->push_back( + std::make_pair(temp_node->current_number, default_edge_)); + } else if (it != input.end()) { + output->push_back(std::make_pair(temp_node->current_number, *it)); + Node *new_node = new (Node); + new_node->current_number = dict_number_++; + new_node->current_edge = *it; + (temp_node->next_number)[*it] = new_node; + } + if (it == input.end()) break; + } +} + +template +bool LempelZiv::BatchDecode( + const std::vector> &input, std::vector *output) { + for (typename std::vector>::const_iterator it = + input.begin(); + it != input.end(); ++it) { + std::vector temp_output; + EdgeEquals InstEdgeEquals; + if (InstEdgeEquals(it->second, default_edge_) != 1) { + decode_vector_.push_back(*it); + temp_output.push_back(it->second); + } + Var temp_integer = it->first; + if (temp_integer >= decode_vector_.size()) { + LOG(ERROR) << "LempelZiv::BatchDecode: " + << "Index exceeded the dictionary size"; + return false; + } else { + while (temp_integer != 0) { + temp_output.push_back(decode_vector_[temp_integer].second); + temp_integer = decode_vector_[temp_integer].first; + } + std::reverse(temp_output.begin(), temp_output.end()); + output->insert(output->end(), temp_output.begin(), temp_output.end()); + } + } + return true; +} + +// The main Compressor class +template +class Compressor { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + Compressor() {} + + // Compresses fst into a boolean vector code. Returns true on sucesss. + bool Compress(const Fst &fst, std::ostream &strm); + + // Decompresses the boolean vector into Fst. Returns true on sucesss. + bool Decompress(std::istream &strm, const string &source, + MutableFst *fst); + + // Finds the BFS order of a fst + void BfsOrder(const ExpandedFst &fst, std::vector *order); + + // Preprocessing step to convert fst to a isomorphic fst + // Returns a preproccess fst and a dictionary + void Preprocess(const Fst &fst, MutableFst *preprocessedfst, + EncodeMapper *encoder); + + // Performs Lempel Ziv and outputs a stream of integers + // and sends it to a stream + void EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm); + + // Decodes fst from the stream + void DecodeProcessedFst(const std::vector &input, + MutableFst *fst, bool unweighted); + + // Converts buffer_code_ to uint8 and writes to a stream. + + // Writes the boolean file to the stream + void WriteToStream(std::ostream &strm); + + // Writes the weights to the stream + void WriteWeight(const std::vector &input, std::ostream &strm); + + void ReadWeight(std::istream &strm, std::vector *output); + + // Same as fst::Decode without the line RmFinalEpsilon(fst) + void DecodeForCompress(MutableFst *fst, const EncodeMapper &mapper); + + // Updates the buffer_code_ + template + void WriteToBuffer(CVar input) { + std::vector current_code; + Elias::DeltaEncode(input, ¤t_code); + if (!buffer_code_.empty()) { + buffer_code_.insert(buffer_code_.end(), current_code.begin(), + current_code.end()); + } else { + buffer_code_.assign(current_code.begin(), current_code.end()); + } + } + + private: + struct LZLabel { + LZLabel() : label(0) {} + Label label; + }; + + struct LabelLessThan { + bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { + return labelone.label < labeltwo.label; + } + }; + + struct LabelEquals { + bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { + return labelone.label == labeltwo.label; + } + }; + + struct Transition { + Transition() : nextstate(0), label(0), weight(Weight::Zero()) {} + + StateId nextstate; + Label label; + Weight weight; + }; + + struct TransitionLessThan { + bool operator()(const Transition &transition_one, + const Transition &transition_two) const { + if (transition_one.nextstate == transition_two.nextstate) + return transition_one.label < transition_two.label; + else + return transition_one.nextstate < transition_two.nextstate; + } + } transition_less_than; + + struct TransitionEquals { + bool operator()(const Transition &transition_one, + const Transition &transition_two) const { + return transition_one.nextstate == transition_two.nextstate && + transition_one.label == transition_two.label; + } + } transition_equals; + + struct OldDictCompare { + bool operator()(const std::pair &pair_one, + const std::pair &pair_two) const { + if ((pair_one.second).nextstate == (pair_two.second).nextstate) + return (pair_one.second).label < (pair_two.second).label; + else + return (pair_one.second).nextstate < (pair_two.second).nextstate; + } + } old_dict_compare; + + std::vector buffer_code_; + std::vector arc_weight_; + std::vector final_weight_; +}; + +template +inline void Compressor::DecodeForCompress( + MutableFst *fst, const EncodeMapper &mapper) { + ArcMap(fst, EncodeMapper(mapper, DECODE)); + fst->SetInputSymbols(mapper.InputSymbols()); + fst->SetOutputSymbols(mapper.OutputSymbols()); +} + +// Compressor::BfsOrder +template +void Compressor::BfsOrder(const ExpandedFst &fst, + std::vector *order) { + Arc arc; + StateId bfs_visit_number = 0; + std::queue states_queue; + order->assign(fst.NumStates(), kNoStateId); + states_queue.push(fst.Start()); + (*order)[fst.Start()] = bfs_visit_number++; + while (!states_queue.empty()) { + for (ArcIterator> aiter(fst, states_queue.front()); !aiter.Done(); + aiter.Next()) { + arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if ((*order)[nextstate] == kNoStateId) { + (*order)[nextstate] = bfs_visit_number++; + states_queue.push(nextstate); + } + } + states_queue.pop(); + } + + // If the FST is unconnected, then the following + // code finds them + while (bfs_visit_number < fst.NumStates()) { + int unseen_state = 0; + for (unseen_state = 0; unseen_state < fst.NumStates(); ++unseen_state) { + if ((*order)[unseen_state] == kNoStateId) break; + } + states_queue.push(unseen_state); + (*order)[unseen_state] = bfs_visit_number++; + while (!states_queue.empty()) { + for (ArcIterator> aiter(fst, states_queue.front()); + !aiter.Done(); aiter.Next()) { + arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if ((*order)[nextstate] == kNoStateId) { + (*order)[nextstate] = bfs_visit_number++; + states_queue.push(nextstate); + } + } + states_queue.pop(); + } + } +} + +template +void Compressor::Preprocess(const Fst &fst, + MutableFst *preprocessedfst, + EncodeMapper *encoder) { + *preprocessedfst = fst; + if (!preprocessedfst->NumStates()) { + return; + } + // Relabels the edges and develops a dictionary + Encode(preprocessedfst, encoder); + std::vector order; + // Finds the BFS sorting order of the fst + BfsOrder(*preprocessedfst, &order); + // Reorders the states according to the BFS order + StateSort(preprocessedfst, order); +} + +template +void Compressor::EncodeProcessedFst(const ExpandedFst &fst, + std::ostream &strm) { + std::vector output; + LempelZiv dict_new; + LempelZiv dict_old; + std::vector current_new_input; + std::vector current_old_input; + std::vector> current_new_output; + std::vector> current_old_output; + std::vector final_states; + + StateId number_of_states = fst.NumStates(); + + StateId seen_states = 0; + // Adding the number of states + WriteToBuffer(number_of_states); + + for (StateId state = 0; state < number_of_states; ++state) { + current_new_input.clear(); + current_old_input.clear(); + current_new_output.clear(); + current_old_output.clear(); + if (state > seen_states) ++seen_states; + + // Collecting the final states + if (fst.Final(state) != Weight::Zero()) { + final_states.push_back(state); + final_weight_.push_back(fst.Final(state)); + } + + // Reading the states + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (arc.nextstate > seen_states) { // RILEY: > or >= ? + ++seen_states; + LZLabel temp_label; + temp_label.label = arc.ilabel; + arc_weight_.push_back(arc.weight); + current_new_input.push_back(temp_label); + } else { + Transition temp_transition; + temp_transition.nextstate = arc.nextstate; + temp_transition.label = arc.ilabel; + temp_transition.weight = arc.weight; + current_old_input.push_back(temp_transition); + } + } + // Adding new states + dict_new.BatchEncode(current_new_input, ¤t_new_output); + WriteToBuffer(current_new_output.size()); + + for (auto it = current_new_output.begin(); it != current_new_output.end(); + ++it) { + WriteToBuffer(it->first); + WriteToBuffer. +// See the FarReader interface in far.h for the exact semantics. +class FarReaderImplBase { + public: + virtual const string &ArcType() const = 0; + virtual bool Done() const = 0; + virtual bool Error() const = 0; + virtual const string &GetKey() const = 0; + virtual const FstClass *GetFstClass() const = 0; + virtual bool Find(const string &key) = 0; + virtual void Next() = 0; + virtual void Reset() = 0; + virtual FarType Type() const = 0; + virtual ~FarReaderImplBase() {} +}; + +// Templated implementation. +template +class FarReaderClassImpl : public FarReaderImplBase { + public: + explicit FarReaderClassImpl(const string &filename) + : impl_(FarReader::Open(filename)) {} + + explicit FarReaderClassImpl(const std::vector &filenames) + : impl_(FarReader::Open(filenames)) {} + + const string &ArcType() const final { return Arc::Type(); } + + bool Done() const final { return impl_->Done(); } + + bool Error() const final { return impl_->Error(); } + + bool Find(const string &key) final { return impl_->Find(key); } + + const FstClass *GetFstClass() const final { + fstc_.reset(new FstClass(*impl_->GetFst())); + return fstc_.get(); + } + + const string &GetKey() const final { return impl_->GetKey(); } + + void Next() final { return impl_->Next(); } + + void Reset() final { impl_->Reset(); } + + FarType Type() const final { return impl_->Type(); } + + const FarReader *GetImpl() const { return impl_.get(); } + + FarReader *GetImpl() { return impl_.get(); } + + private: + std::unique_ptr> impl_; + mutable std::unique_ptr fstc_; +}; + + +class FarReaderClass; + +using OpenFarReaderClassArgs = + WithReturnValue &>; + +// Untemplated user-facing class holding a templated pimpl. +class FarReaderClass { + public: + const string &ArcType() const { return impl_->ArcType(); } + + bool Done() const { return impl_->Done(); } + + // Returns True if the impl is null (i.e., due to read failure). + // Attempting to call any other function will result in null dereference. + bool Error() const { return (impl_) ? impl_->Error() : true; } + + bool Find(const string &key) { return impl_->Find(key); } + + const FstClass *GetFstClass() const { return impl_->GetFstClass(); } + + const string &GetKey() const { return impl_->GetKey(); } + + void Next() { impl_->Next(); } + + void Reset() { impl_->Reset(); } + + FarType Type() const { return impl_->Type(); } + + template + const FarReader *GetFarReader() const { + if (Arc::Type() != ArcType()) return nullptr; + const FarReaderClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + FarReader *GetFarReader() { + if (Arc::Type() != ArcType()) return nullptr; + FarReaderClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + friend void OpenFarReaderClass(OpenFarReaderClassArgs *args); + + // Defined in the CC. + + static FarReaderClass *Open(const string &filename); + + static FarReaderClass *Open(const std::vector &filenames); + + private: + template + explicit FarReaderClass(FarReaderClassImpl *impl) : impl_(impl) {} + + std::unique_ptr impl_; +}; + +// These exist solely for registration purposes; users should call the +// static method FarReaderClass::Open instead. + +template +void OpenFarReaderClass(OpenFarReaderClassArgs *args) { + args->retval = new FarReaderClass(new FarReaderClassImpl(args->args)); +} + +// FarWriter API. + +// Virtual interface implemented by each concrete FarWriterImpl. +class FarWriterImplBase { + public: + // Unlike the lower-level library, this returns a boolean to signal failure + // due to non-conformant arc types. + virtual bool Add(const string &key, const FstClass &fst) = 0; + virtual const string &ArcType() const = 0; + virtual bool Error() const = 0; + virtual FarType Type() const = 0; + virtual ~FarWriterImplBase() {} +}; + + +// Templated implementation. +template +class FarWriterClassImpl : public FarWriterImplBase { + public: + explicit FarWriterClassImpl(const string &filename, + FarType type = FAR_DEFAULT) + : impl_(FarWriter::Create(filename, type)) {} + + bool Add(const string &key, const FstClass &fst) final { + if (ArcType() != fst.ArcType()) { + FSTERROR() << "Cannot write FST with " << fst.ArcType() << " arcs to " + << "FAR with " << ArcType() << " arcs"; + return false; + } + impl_->Add(key, *(fst.GetFst())); + return true; + } + + const string &ArcType() const final { return Arc::Type(); } + + bool Error() const final { return impl_->Error(); } + + FarType Type() const final { return impl_->Type(); } + + const FarWriter *GetImpl() const { return impl_.get(); } + + FarWriter *GetImpl() { return impl_.get(); } + + private: + std::unique_ptr> impl_; +}; + + +class FarWriterClass; + +using CreateFarWriterClassInnerArgs = std::pair; + +using CreateFarWriterClassArgs = + WithReturnValue; + +// Untemplated user-facing class holding a templated pimpl. +class FarWriterClass { + public: + static FarWriterClass *Create(const string &filename, const string &arc_type, + FarType type = FAR_DEFAULT); + + bool Add(const string &key, const FstClass &fst) { + return impl_->Add(key, fst); + } + + // Returns True if the impl is null (i.e., due to construction failure). + // Attempting to call any other function will result in null dereference. + bool Error() const { return (impl_) ? impl_->Error() : true; } + + const string &ArcType() const { return impl_->ArcType(); } + + FarType Type() const { return impl_->Type(); } + + template + const FarWriter *GetFarWriter() const { + if (Arc::Type() != ArcType()) return nullptr; + const FarWriterClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + FarWriter *GetFarWriter() { + if (Arc::Type() != ArcType()) return nullptr; + FarWriterClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + friend void CreateFarWriterClass(CreateFarWriterClassArgs *args); + + private: + template + explicit FarWriterClass(FarWriterClassImpl *impl) : impl_(impl) {} + + std::unique_ptr impl_; +}; + +// This exists solely for registration purposes; users should call the +// static method FarWriterClass::Create instead. +template +void CreateFarWriterClass(CreateFarWriterClassArgs *args) { + args->retval = new FarWriterClass(new FarWriterClassImpl( + std::get<0>(args->args), std::get<1>(args->args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/far.h b/projects/llm_framework/include/fst/extensions/far/far.h new file mode 100644 index 00000000..c24c7dab --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/far.h @@ -0,0 +1,481 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Finite-State Transducer (FST) archive classes. + +#ifndef FST_EXTENSIONS_FAR_FAR_H_ +#define FST_EXTENSIONS_FAR_FAR_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace fst { + +enum FarEntryType { FET_LINE, FET_FILE }; + +enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; + +inline bool IsFst(const string &filename) { + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) return false; + return IsFstHeader(strm, filename); +} + +// FST archive header class +class FarHeader { + public: + const string &ArcType() const { return arctype_; } + + const string &FarType() const { return fartype_; } + + bool Read(const string &filename) { + FstHeader fsthdr; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; + } else if (IsSTTable(filename)) { // Checks if STTable. + ReadSTTableHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsSTList(filename)) { // Checks if STList. + ReadSTListHeader(filename, &fsthdr); + fartype_ = "stlist"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsFst(filename)) { // Checks if FST. + std::ifstream istrm(filename, + std::ios_base::in | std::ios_base::binary); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } + return false; + } + + private: + string fartype_; + string arctype_; +}; + +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; + +// This class creates an archive of FSTs. +template +class FarWriter { + public: + using Arc = A; + + // Creates a new (empty) FST archive; returns null on error. + static FarWriter *Create(const string &filename, FarType type = FAR_DEFAULT); + + // Adds an FST to the end of an archive. Keys must be non-empty and + // in lexicographic order. FSTs must have a suitable write method. + virtual void Add(const string &key, const Fst &fst) = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarWriter() {} + + protected: + FarWriter() {} +}; + +// This class iterates through an existing archive of FSTs. +template +class FarReader { + public: + using Arc = A; + + // Opens an existing FST archive in a single file; returns null on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const string &filename); + + // Opens an existing FST archive in multiple files; returns null on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const std::vector &filenames); + + // Resets current position to beginning of archive. + virtual void Reset() = 0; + + // Sets current position to first entry >= key. Returns true if a match. + virtual bool Find(const string &key) = 0; + + // Current position at end of archive? + virtual bool Done() const = 0; + + // Move current position to next FST. + virtual void Next() = 0; + + // Returns key at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const string &GetKey() const = 0; + + // Returns pointer to FST at the current position. This is invalidated if + // the current position in the archive is changed. + virtual const Fst *GetFst() const = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarReader() {} + + protected: + FarReader() {} +}; + +template +class FstWriter { + public: + void operator()(std::ostream &strm, const Fst &fst) const { + fst.Write(strm, FstWriteOptions()); + } +}; + +template +class STTableFarWriter : public FarWriter { + public: + using Arc = A; + + static STTableFarWriter *Create(const string &filename) { + auto *writer = STTableWriter, FstWriter>::Create(filename); + return new STTableFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) final { + writer_->Add(key, fst); + } + + FarType Type() const final { return FAR_STTABLE; } + + bool Error() const final { return writer_->Error(); } + + private: + explicit STTableFarWriter(STTableWriter, FstWriter> *writer) + : writer_(writer) {} + + std::unique_ptr, FstWriter>> writer_; +}; + +template +class STListFarWriter : public FarWriter { + public: + using Arc = A; + + static STListFarWriter *Create(const string &filename) { + auto *writer = STListWriter, FstWriter>::Create(filename); + return new STListFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) final { + writer_->Add(key, fst); + } + + constexpr FarType Type() const final { return FAR_STLIST; } + + bool Error() const final { return writer_->Error(); } + + private: + explicit STListFarWriter(STListWriter, FstWriter> *writer) + : writer_(writer) {} + + std::unique_ptr, FstWriter>> writer_; +}; + +template +class FstFarWriter : public FarWriter { + public: + using Arc = A; + + explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst &fst) final { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one FST supported," + << " subsequent entries discarded."; + } else { + error_ = !fst.Write(filename_); + written_ = true; + } + } + + constexpr FarType Type() const final { return FAR_FST; } + + bool Error() const final { return error_; } + + ~FstFarWriter() final {} + + private: + string filename_; + bool error_; + bool written_; +}; + +template +FarWriter *FarWriter::Create(const string &filename, FarType type) { + switch (type) { + case FAR_DEFAULT: + if (filename.empty()) return STListFarWriter::Create(filename); + case FAR_STTABLE: + return STTableFarWriter::Create(filename); + case FAR_STLIST: + return STListFarWriter::Create(filename); + case FAR_FST: + return FstFarWriter::Create(filename); + default: + LOG(ERROR) << "FarWriter::Create: Unknown FAR type"; + return nullptr; + } +} + +template +class FstReader { + public: + Fst *operator()(std::istream &strm) const { + return Fst::Read(strm, FstReadOptions()); + } +}; + +template +class STTableFarReader : public FarReader { + public: + using Arc = A; + + static STTableFarReader *Open(const string &filename) { + auto *reader = STTableReader, FstReader>::Open(filename); + if (!reader || reader->Error()) return nullptr; + return new STTableFarReader(reader); + } + + static STTableFarReader *Open(const std::vector &filenames) { + auto *reader = STTableReader, FstReader>::Open(filenames); + if (!reader || reader->Error()) return nullptr; + return new STTableFarReader(reader); + } + + void Reset() final { reader_->Reset(); } + + bool Find(const string &key) final { return reader_->Find(key); } + + bool Done() const final { return reader_->Done(); } + + void Next() final { return reader_->Next(); } + + const string &GetKey() const final { return reader_->GetKey(); } + + const Fst *GetFst() const final { return reader_->GetEntry(); } + + constexpr FarType Type() const final { return FAR_STTABLE; } + + bool Error() const final { return reader_->Error(); } + + private: + explicit STTableFarReader(STTableReader, FstReader> *reader) + : reader_(reader) {} + + std::unique_ptr, FstReader>> reader_; +}; + +template +class STListFarReader : public FarReader { + public: + using Arc = A; + + static STListFarReader *Open(const string &filename) { + auto *reader = STListReader, FstReader>::Open(filename); + if (!reader || reader->Error()) return nullptr; + return new STListFarReader(reader); + } + + static STListFarReader *Open(const std::vector &filenames) { + auto *reader = STListReader, FstReader>::Open(filenames); + if (!reader || reader->Error()) return nullptr; + return new STListFarReader(reader); + } + + void Reset() final { reader_->Reset(); } + + bool Find(const string &key) final { return reader_->Find(key); } + + bool Done() const final { return reader_->Done(); } + + void Next() final { return reader_->Next(); } + + const string &GetKey() const final { return reader_->GetKey(); } + + const Fst *GetFst() const final { return reader_->GetEntry(); } + + constexpr FarType Type() const final { return FAR_STLIST; } + + bool Error() const final { return reader_->Error(); } + + private: + explicit STListFarReader(STListReader, FstReader> *reader) + : reader_(reader) {} + + std::unique_ptr, FstReader>> reader_; +}; + +template +class FstFarReader : public FarReader { + public: + using Arc = A; + + static FstFarReader *Open(const string &filename) { + std::vector filenames; + filenames.push_back(filename); + return new FstFarReader(filenames); + } + + static FstFarReader *Open(const std::vector &filenames) { + return new FstFarReader(filenames); + } + + explicit FstFarReader(const std::vector &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), error_(false) { + std::sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &std::cin; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: standard input should " + "only appear once in the input file list"; + error_ = true; + return; + } + } else { + streams_[i] = new std::ifstream( + keys_[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "FstFarReader::FstFarReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() final { + if (has_stdin_) { + FSTERROR() + << "FstFarReader::Reset: Operation not supported on standard input"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) final { + if (has_stdin_) { + FSTERROR() + << "FstFarReader::Find: Operation not supported on standard input"; + error_ = true; + return false; + } + pos_ = 0; // TODO + ReadFst(); + return true; + } + + bool Done() const final { return error_ || pos_ >= keys_.size(); } + + void Next() final { + ++pos_; + ReadFst(); + } + + const string &GetKey() const final { return keys_[pos_]; } + + const Fst *GetFst() const final { return fst_.get(); } + + constexpr FarType Type() const final { return FAR_FST; } + + bool Error() const final { return error_; } + + ~FstFarReader() final { + for (size_t i = 0; i < keys_.size(); ++i) { + if (streams_[i] != &std::cin) { + delete streams_[i]; + } + } + } + + private: + void ReadFst() { + fst_.reset(); + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_.reset(Fst::Read(*streams_[pos_], FstReadOptions())); + if (!fst_) { + FSTERROR() << "FstFarReader: Error reading Fst from: " << keys_[pos_]; + error_ = true; + } + } + + std::vector keys_; + std::vector streams_; + bool has_stdin_; + size_t pos_; + mutable std::unique_ptr> fst_; + mutable bool error_; +}; + +template +FarReader *FarReader::Open(const string &filename) { + if (filename.empty()) + return STListFarReader::Open(filename); + else if (IsSTTable(filename)) + return STTableFarReader::Open(filename); + else if (IsSTList(filename)) + return STListFarReader::Open(filename); + else if (IsFst(filename)) + return FstFarReader::Open(filename); + return nullptr; +} + +template +FarReader *FarReader::Open(const std::vector &filenames) { + if (!filenames.empty() && filenames[0].empty()) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsSTTable(filenames[0])) + return STTableFarReader::Open(filenames); + else if (!filenames.empty() && IsSTList(filenames[0])) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader::Open(filenames); + return nullptr; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/farlib.h b/projects/llm_framework/include/fst/extensions/far/farlib.h new file mode 100644 index 00000000..c9bb1710 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/farlib.h @@ -0,0 +1,19 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A finite-state archive (FAR) is used to store an indexable collection of +// FSTs in a single file. Utilities are provided to create FARs from FSTs, +// to iterate over FARs, and to extract specific FSTs from FARs. + +#ifndef FST_EXTENSIONS_FAR_FARLIB_H_ +#define FST_EXTENSIONS_FAR_FARLIB_H_ + +#include +#include +#include +#include +#include +#include +#include + +#endif // FST_EXTENSIONS_FAR_FARLIB_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/farscript.h b/projects/llm_framework/include/fst/extensions/far/farscript.h new file mode 100644 index 00000000..4bd11a94 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/farscript.h @@ -0,0 +1,269 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Convenience file for including all of the FAR operations, or registering +// them for new arc types. + +#ifndef FST_EXTENSIONS_FAR_FARSCRIPT_H_ +#define FST_EXTENSIONS_FAR_FARSCRIPT_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because this struct is +// only used to pass them deeper in the call graph. Be sure you understand why +// this is so before using this struct for anything else! +struct FarCompileStringsArgs { + const std::vector &in_fnames; + const string &out_fname; + const string &fst_type; + const FarType &far_type; + const int32 generate_keys; + const FarEntryType fet; + const FarTokenType tt; + const string &symbols_fname; + const string &unknown_symbol; + const bool keep_symbols; + const bool initial_symbols; + const bool allow_negative_labels; + const string &key_prefix; + const string &key_suffix; + + FarCompileStringsArgs(const std::vector &in_fnames, + const string &out_fname, const string &fst_type, + const FarType &far_type, int32 generate_keys, + FarEntryType fet, FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, + const string &key_prefix, const string &key_suffix) + : in_fnames(in_fnames), + out_fname(out_fname), + fst_type(fst_type), + far_type(far_type), + generate_keys(generate_keys), + fet(fet), + tt(tt), + symbols_fname(symbols_fname), + unknown_symbol(unknown_symbol), + keep_symbols(keep_symbols), + initial_symbols(initial_symbols), + allow_negative_labels(allow_negative_labels), + key_prefix(key_prefix), + key_suffix(key_suffix) {} +}; + +template +void FarCompileStrings(FarCompileStringsArgs *args) { + FarCompileStrings( + args->in_fnames, args->out_fname, args->fst_type, args->far_type, + args->generate_keys, args->fet, args->tt, args->symbols_fname, + args->unknown_symbol, args->keep_symbols, args->initial_symbols, + args->allow_negative_labels, args->key_prefix, args->key_suffix); +} + +void FarCompileStrings(const std::vector &in_fnames, + const string &out_fname, const string &arc_type, + const string &fst_type, const FarType &far_type, + int32 generate_keys, FarEntryType fet, FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, + const string &key_prefix, const string &key_suffix); + +// Note: it is safe to pass these strings as references because this struct is +// only used to pass them deeper in the call graph. Be sure you understand why +// this is so before using this struct for anything else! +struct FarCreateArgs { + const std::vector &in_fnames; + const string &out_fname; + const int32 generate_keys; + const FarType &far_type; + const string &key_prefix; + const string &key_suffix; + + FarCreateArgs(const std::vector &in_fnames, const string &out_fname, + const int32 generate_keys, const FarType &far_type, + const string &key_prefix, const string &key_suffix) + : in_fnames(in_fnames), + out_fname(out_fname), + generate_keys(generate_keys), + far_type(far_type), + key_prefix(key_prefix), + key_suffix(key_suffix) {} +}; + +template +void FarCreate(FarCreateArgs *args) { + FarCreate(args->in_fnames, args->out_fname, args->generate_keys, + args->far_type, args->key_prefix, args->key_suffix); +} + +void FarCreate(const std::vector &in_fnames, const string &out_fname, + const string &arc_type, const int32 generate_keys, + const FarType &far_type, const string &key_prefix, + const string &key_suffix); + +using FarEqualInnerArgs = std::tuple; + +using FarEqualArgs = WithReturnValue; + +template +void FarEqual(FarEqualArgs *args) { + args->retval = fst::FarEqual( + std::get<0>(args->args), std::get<1>(args->args), std::get<2>(args->args), + std::get<3>(args->args), std::get<4>(args->args)); +} + +bool FarEqual(const string &filename1, const string &filename2, + const string &arc_type, float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + +using FarExtractArgs = + std::tuple &, int32, const string &, + const string &, const string &, const string &, const string &>; + +template +void FarExtract(FarExtractArgs *args) { + fst::FarExtract(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args), std::get<5>(*args), + std::get<6>(*args)); +} + +void FarExtract(const std::vector &ifilenames, const string &arc_type, + int32 generate_filenames, const string &keys, + const string &key_separator, const string &range_delimiter, + const string &filename_prefix, const string &filename_suffix); + +using FarInfoArgs = std::tuple &, const string &, + const string &, const bool>; + +template +void FarInfo(FarInfoArgs *args) { + fst::FarInfo(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args)); +} + +void FarInfo(const std::vector &filenames, const string &arc_type, + const string &begin_key, const string &end_key, + const bool list_fsts); + +using GetFarInfoArgs = std::tuple &, const string &, + const string &, const bool, FarInfoData *>; + +template +void GetFarInfo(GetFarInfoArgs *args) { + fst::GetFarInfo(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args)); +} + +void GetFarInfo(const std::vector &filenames, const string &arc_type, + const string &begin_key, const string &end_key, + const bool list_fsts, FarInfoData *); + +using FarIsomorphicInnerArgs = std::tuple; + +using FarIsomorphicArgs = WithReturnValue; + +template +void FarIsomorphic(FarIsomorphicArgs *args) { + args->retval = fst::FarIsomorphic( + std::get<0>(args->args), std::get<1>(args->args), std::get<2>(args->args), + std::get<3>(args->args), std::get<4>(args->args)); +} + +bool FarIsomorphic(const string &filename1, const string &filename2, + const string &arc_type, float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + +struct FarPrintStringsArgs { + const std::vector &ifilenames; + const FarEntryType entry_type; + const FarTokenType token_type; + const string &begin_key; + const string &end_key; + const bool print_key; + const bool print_weight; + const string &symbols_fname; + const bool initial_symbols; + const int32 generate_filenames; + const string &filename_prefix; + const string &filename_suffix; + + FarPrintStringsArgs(const std::vector &ifilenames, + const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, + const bool print_weight, const string &symbols_fname, + const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix) + : ifilenames(ifilenames), + entry_type(entry_type), + token_type(token_type), + begin_key(begin_key), + end_key(end_key), + print_key(print_key), + print_weight(print_weight), + symbols_fname(symbols_fname), + initial_symbols(initial_symbols), + generate_filenames(generate_filenames), + filename_prefix(filename_prefix), + filename_suffix(filename_suffix) {} +}; + +template +void FarPrintStrings(FarPrintStringsArgs *args) { + fst::FarPrintStrings( + args->ifilenames, args->entry_type, args->token_type, args->begin_key, + args->end_key, args->print_key, args->print_weight, args->symbols_fname, + args->initial_symbols, args->generate_filenames, args->filename_prefix, + args->filename_suffix); +} + +void FarPrintStrings(const std::vector &ifilenames, + const string &arc_type, const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, + const bool print_weight, const string &symbols_fname, + const bool initial_symbols, const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix); + +} // namespace script +} // namespace fst + +#define REGISTER_FST_FAR_OPERATIONS(ArcType) \ + REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \ + REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \ + REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \ + REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \ + REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \ + REGISTER_FST_OPERATION(FarIsomorphic, ArcType, FarIsomorphicArgs); \ + REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs); \ + REGISTER_FST_OPERATION(GetFarInfo, ArcType, GetFarInfoArgs) + +#endif // FST_EXTENSIONS_FAR_FARSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/getters.h b/projects/llm_framework/include/fst/extensions/far/getters.h new file mode 100644 index 00000000..3dde4194 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/getters.h @@ -0,0 +1,30 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions for registering and invoking FAR main +// functions that support multiple and extensible arc types. + +#ifndef FST_EXTENSIONS_FAR_GETTERS_H_ +#define FST_EXTENSIONS_FAR_GETTERS_H_ + +#include +#include + +namespace fst { +namespace script { + +FarType GetFarType(const string &str); + +bool GetFarEntryType(const string &str, FarEntryType *entry_type); + +bool GetFarTokenType(const string &str, FarTokenType *token_type); + +void ExpandArgs(int argc, char **argv, int *argcp, char ***argvp); + +} // namespace script + +string GetFarTypeString(FarType type); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/info.h b/projects/llm_framework/include/fst/extensions/far/info.h new file mode 100644 index 00000000..0391c1f4 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/info.h @@ -0,0 +1,147 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_FAR_INFO_H_ +#define FST_EXTENSIONS_FAR_INFO_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fst { + +template +void AccumulateStatesAndArcs(const Fst &fst, size_t *nstate, size_t *narc, + size_t *nfinal) { + for (StateIterator> siter(fst); !siter.Done(); + siter.Next(), ++(*nstate)) { + ArcIterator> aiter(fst, siter.Value()); + for (; !aiter.Done(); aiter.Next(), ++(*narc)) { + } + if (fst.Final(siter.Value()) != Arc::Weight::Zero()) ++(*nfinal); + } +} + +struct KeyInfo { + string key; + string type; + size_t nstate = 0; + size_t narc = 0; + size_t nfinal = 0; +}; + +struct FarInfoData { + std::vector key_infos; + string far_type; + string arc_type; + size_t nfst = 0; + size_t nstate = 0; + size_t narc = 0; + size_t nfinal = 0; + std::set fst_types; +}; + +template +void GetFarInfo(const std::vector &filenames, const string &begin_key, + const string &end_key, const bool list_fsts, + FarInfoData *far_info) { + *far_info = FarInfoData(); + std::unique_ptr> reader(FarReader::Open(filenames)); + if (!reader) { + LOG(ERROR) << "GetFarInfo: failed to create far reader."; + return; + } + if (!begin_key.empty()) reader->Find(begin_key); + + for (; !reader->Done(); reader->Next()) { + const auto &key = reader->GetKey(); + if (!end_key.empty() && end_key < key) break; + ++far_info->nfst; + const auto *fst = reader->GetFst(); + far_info->fst_types.insert(fst->Type()); + if (list_fsts) { + KeyInfo info; + info.key = key; + info.type = fst->Type(); + AccumulateStatesAndArcs(*fst, &info.nstate, &info.narc, &info.nfinal); + far_info->nstate += info.nstate; + far_info->narc += info.narc; + far_info->nfinal += info.nfinal; + far_info->key_infos.push_back(info); + } else { + AccumulateStatesAndArcs(*fst, &far_info->nstate, &far_info->narc, + &far_info->nfinal); + } + } + far_info->far_type = GetFarTypeString(reader->Type()); + far_info->arc_type = Arc::Type(); +} + +template +void FarInfo(const std::vector &filenames, const string &begin_key, + const string &end_key, const bool list_fsts) { + FarInfoData info; + GetFarInfo(filenames, begin_key, end_key, list_fsts, &info); + if (!list_fsts) { + std::cout << std::left << std::setw(50) << "far type" << info.far_type + << std::endl; + std::cout << std::left << std::setw(50) << "arc type" << Arc::Type() + << std::endl; + std::cout << std::left << std::setw(50) << "fst type"; + for (auto iter = info.fst_types.begin(); iter != info.fst_types.end(); + ++iter) { + if (iter != info.fst_types.begin()) std::cout << ","; + std::cout << *iter; + } + std::cout << std::endl; + std::cout << std::left << std::setw(50) << "# of FSTs" << info.nfst + << std::endl; + std::cout << std::left << std::setw(50) << "total # of states" + << info.nstate << std::endl; + std::cout << std::left << std::setw(50) << "total # of arcs" << info.narc + << std::endl; + std::cout << std::left << std::setw(50) << "total # of final states" + << info.nfinal << std::endl; + } else { + // FIXME(kbg): Grok, then document this. + int wkey = 10; + int wtype = 10; + int wnstate = 14; + int wnarc = 12; + int wnfinal = 20; + for (const auto &key_info : info.key_infos) { + if (key_info.key.size() + 2 > wkey) wkey = key_info.key.size() + 2; + if (key_info.type.size() + 2 > wtype) wtype = key_info.type.size() + 2; + if (ceil(log10(key_info.nstate)) + 2 > wnstate) { + wnstate = ceil(log10(key_info.nstate)) + 2; + } + if (ceil(log10(key_info.narc)) + 2 > wnarc) { + wnarc = ceil(log10(key_info.narc)) + 2; + } + if (ceil(log10(key_info.nfinal)) + 2 > wnfinal) { + wnfinal = ceil(log10(key_info.nfinal)) + 2; + } + } + std::cout << std::left << std::setw(wkey) << "key" << std::setw(wtype) + << "type" << std::right << std::setw(wnstate) << "# of states" + << std::setw(wnarc) << "# of arcs" << std::setw(wnfinal) + << "# of final states" << std::endl; + for (const auto &key_info : info.key_infos) { + std::cout << std::left << std::setw(wkey) << key_info.key + << std::setw(wtype) << key_info.type << std::right + << std::setw(wnstate) << key_info.nstate << std::setw(wnarc) + << key_info.narc << std::setw(wnfinal) << key_info.nfinal + << std::endl; + } + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_INFO_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/isomorphic.h b/projects/llm_framework/include/fst/extensions/far/isomorphic.h new file mode 100644 index 00000000..1e6e9cb3 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/isomorphic.h @@ -0,0 +1,69 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_FAR_ISOMORPHIC_H_ +#define FST_EXTENSIONS_FAR_ISOMORPHIC_H_ + +#include +#include + +#include +#include + +namespace fst { + +template +bool FarIsomorphic(const string &filename1, const string &filename2, + float delta = kDelta, const string &begin_key = string(), + const string &end_key = string()) { + std::unique_ptr> reader1(FarReader::Open(filename1)); + if (!reader1) { + LOG(ERROR) << "FarIsomorphic: Cannot open FAR file " << filename1; + return false; + } + std::unique_ptr> reader2(FarReader::Open(filename2)); + if (!reader2) { + LOG(ERROR) << "FarIsomorphic: Cannot open FAR file " << filename2; + return false; + } + if (!begin_key.empty()) { + bool find_begin1 = reader1->Find(begin_key); + bool find_begin2 = reader2->Find(begin_key); + if (!find_begin1 || !find_begin2) { + bool ret = !find_begin1 && !find_begin2; + if (!ret) { + VLOG(1) << "FarIsomorphic: Key " << begin_key << " missing from " + << (find_begin1 ? "second" : "first") << " archive."; + } + return ret; + } + } + for (; !reader1->Done() && !reader2->Done(); + reader1->Next(), reader2->Next()) { + const auto &key1 = reader1->GetKey(); + const auto &key2 = reader2->GetKey(); + if (!end_key.empty() && end_key < key1 && end_key < key2) return true; + if (key1 != key2) { + LOG(ERROR) << "FarIsomorphic: Mismatched keys " << key1 << " and " + << key2; + return false; + } + if (!Isomorphic(*(reader1->GetFst()), *(reader2->GetFst()), delta)) { + LOG(ERROR) << "FarIsomorphic: FSTs for key " << key1 + << " are not isomorphic"; + return false; + } + } + if (!reader1->Done() || !reader2->Done()) { + LOG(ERROR) << "FarIsomorphic: Key " + << (reader1->Done() ? reader2->GetKey() : reader1->GetKey()) + << " missing form " << (reader2->Done() ? "first" : "second") + << " archive"; + return false; + } + return true; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/print-strings.h b/projects/llm_framework/include/fst/extensions/far/print-strings.h new file mode 100644 index 00000000..dc428401 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/print-strings.h @@ -0,0 +1,105 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Outputs as strings the string FSTs in a finite-state archive. + +#ifndef FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ +#define FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +DECLARE_string(far_field_separator); + +namespace fst { + +template +void FarPrintStrings(const std::vector &ifilenames, + FarEntryType entry_type, FarTokenType far_token_type, + const string &begin_key, const string &end_key, + bool print_key, bool print_weight, + const string &symbols_fname, bool initial_symbols, + int32 generate_filenames, const string &filename_prefix, + const string &filename_suffix) { + StringTokenType token_type; + if (far_token_type == FTT_SYMBOL) { + token_type = StringTokenType::SYMBOL; + } else if (far_token_type == FTT_BYTE) { + token_type = StringTokenType::BYTE; + } else if (far_token_type == FTT_UTF8) { + token_type = StringTokenType::UTF8; + } else { + FSTERROR() << "FarPrintStrings: Unknown token type"; + return; + } + std::unique_ptr syms; + if (!symbols_fname.empty()) { + // TODO(kbg): Allow negative flag? + const SymbolTableTextOptions opts(true); + syms.reset(SymbolTable::ReadText(symbols_fname, opts)); + if (!syms) { + LOG(ERROR) << "FarPrintStrings: Error reading symbol table " + << symbols_fname; + return; + } + } + std::unique_ptr> far_reader(FarReader::Open(ifilenames)); + if (!far_reader) return; + if (!begin_key.empty()) far_reader->Find(begin_key); + string okey; + int nrep = 0; + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + const auto &key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) break; + if (okey == key) { + ++nrep; + } else { + nrep = 0; + } + okey = key; + const auto *fst = far_reader->GetFst(); + if (i == 1 && initial_symbols && !syms && fst->InputSymbols()) + syms.reset(fst->InputSymbols()->Copy()); + string str; + VLOG(2) << "Handling key: " << key; + StringPrinter string_printer(token_type, + syms ? syms.get() : fst->InputSymbols()); + string_printer(*fst, &str); + if (entry_type == FET_LINE) { + if (print_key) std::cout << key << FLAGS_far_field_separator[0]; + std::cout << str; + if (print_weight) + std::cout << FLAGS_far_field_separator[0] << ShortestDistance(*fst); + std::cout << std::endl; + } else if (entry_type == FET_FILE) { + std::stringstream sstrm; + if (generate_filenames) { + sstrm.fill('0'); + sstrm << std::right << std::setw(generate_filenames) << i; + } else { + sstrm << key; + if (nrep > 0) sstrm << "." << nrep; + } + string filename; + filename = filename_prefix + sstrm.str() + filename_suffix; + std::ofstream ostrm(filename); + if (!ostrm) { + LOG(ERROR) << "FarPrintStrings: Can't open file: " << filename; + return; + } + ostrm << str; + if (token_type == StringTokenType::SYMBOL) ostrm << "\n"; + } + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/script-impl.h b/projects/llm_framework/include/fst/extensions/far/script-impl.h new file mode 100644 index 00000000..a0586cc3 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/script-impl.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions for registering and invoking Far main +// functions that support multiple and extensible arc types. + +#ifndef FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ +#define FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ + +#include + +#include +namespace fst { +namespace script { + +string LoadArcTypeFromFar(const string &far_fname); + +string LoadArcTypeFromFst(const string &fst_fname); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/stlist.h b/projects/llm_framework/include/fst/extensions/far/stlist.h new file mode 100644 index 00000000..b155e17e --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/stlist.h @@ -0,0 +1,273 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A generic (string,type) list file format. +// +// This is a stripped-down version of STTable that does not support the Find() +// operation but that does support reading/writting from standard in/out. + +#ifndef FST_EXTENSIONS_FAR_STLIST_H_ +#define FST_EXTENSIONS_FAR_STLIST_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fst { + +static constexpr int32 kSTListMagicNumber = 5656924; +static constexpr int32 kSTListFileVersion = 1; + +// String-type list writing class for object of type T using a functor Writer. +// The Writer functor must provide at least the following interface: +// +// struct Writer { +// void operator()(std::ostream &, const T &) const; +// }; +template +class STListWriter { + public: + explicit STListWriter(const string &filename) + : stream_(filename.empty() ? &std::cout : new std::ofstream( + filename, + std::ios_base::out | + std::ios_base::binary)), + error_(false) { + WriteType(*stream_, kSTListMagicNumber); + WriteType(*stream_, kSTListFileVersion); + if (!stream_) { + FSTERROR() << "STListWriter::STListWriter: Error writing to file: " + << filename; + error_ = true; + } + } + + static STListWriter *Create(const string &filename) { + return new STListWriter(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STListWriter::Add: Key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STListWriter::Add: Key out of order: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + WriteType(*stream_, key); + entry_writer_(*stream_, t); + } + + bool Error() const { return error_; } + + ~STListWriter() { + WriteType(*stream_, string()); + if (stream_ != &std::cout) delete stream_; + } + + private: + Writer entry_writer_; + std::ostream *stream_; // Output stream. + string last_key_; // Last key. + bool error_; + + STListWriter(const STListWriter &) = delete; + STListWriter &operator=(const STListWriter &) = delete; +}; + +// String-type list reading class for object of type T using a functor Reader. +// Reader must provide at least the following interface: +// +// struct Reader { +// T *operator()(std::istream &) const; +// }; +template +class STListReader { + public: + explicit STListReader(const std::vector &filenames) + : sources_(filenames), error_(false) { + streams_.resize(filenames.size(), 0); + bool has_stdin = false; + for (size_t i = 0; i < filenames.size(); ++i) { + if (filenames[i].empty()) { + if (!has_stdin) { + streams_[i] = &std::cin; + sources_[i] = "stdin"; + has_stdin = true; + } else { + FSTERROR() << "STListReader::STListReader: Cannot read multiple " + << "inputs from standard input"; + error_ = true; + return; + } + } else { + streams_[i] = new std::ifstream( + filenames[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "STListReader::STListReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + int32 magic_number = 0; + ReadType(*streams_[i], &magic_number); + int32 file_version = 0; + ReadType(*streams_[i], &file_version); + if (magic_number != kSTListMagicNumber) { + FSTERROR() << "STListReader::STListReader: Wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTListFileVersion) { + FSTERROR() << "STListReader::STListReader: Wrong file version: " + << filenames[i]; + error_ = true; + return; + } + string key; + ReadType(*streams_[i], &key); + if (!key.empty()) heap_.push(std::make_pair(key, i)); + if (!*streams_[i]) { + FSTERROR() << "STListReader: Error reading file: " << sources_[i]; + error_ = true; + return; + } + } + if (heap_.empty()) return; + const auto current = heap_.top().second; + entry_.reset(entry_reader_(*streams_[current])); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: Error reading entry for key " + << heap_.top().first << ", file " << sources_[current]; + error_ = true; + } + } + + ~STListReader() { + for (auto &stream : streams_) { + if (stream != &std::cin) delete stream; + } + } + + static STListReader *Open(const string &filename) { + std::vector filenames; + filenames.push_back(filename); + return new STListReader(filenames); + } + + static STListReader *Open(const std::vector &filenames) { + return new STListReader(filenames); + } + + void Reset() { + FSTERROR() << "STListReader::Reset: Operation not supported"; + error_ = true; + } + + bool Find(const string &key) { + FSTERROR() << "STListReader::Find: Operation not supported"; + error_ = true; + return false; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + auto current = heap_.top().second; + string key; + heap_.pop(); + ReadType(*(streams_[current]), &key); + if (!*streams_[current]) { + FSTERROR() << "STListReader: Error reading file: " << sources_[current]; + error_ = true; + return; + } + if (!key.empty()) heap_.push(std::make_pair(key, current)); + if (!heap_.empty()) { + current = heap_.top().second; + entry_.reset(entry_reader_(*streams_[current])); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: Error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + } + + const string &GetKey() const { return heap_.top().first; } + + const T *GetEntry() const { return entry_.get(); } + + bool Error() const { return error_; } + + private: + Reader entry_reader_; // Read functor. + std::vector streams_; // Input streams. + std::vector sources_; // Corresponding filenames. + std::priority_queue< + std::pair, std::vector>, + std::greater>> heap_; // (Key, stream id) heap + mutable std::unique_ptr entry_; // The currently read entry. + bool error_; + + STListReader(const STListReader &) = delete; + STListReader &operator=(const STListReader &) = delete; +}; + +// String-type list header reading function, templated on the entry header type. +// The Header type must provide at least the following interface: +// +// struct Header { +// void Read(std::istream &strm, const string &filename); +// }; +template +bool ReadSTListHeader(const string &filename, Header *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTListHeader: Can't read header from standard input"; + return false; + } + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: Could not open file: " << filename; + return false; + } + int32 magic_number = 0; + ReadType(strm, &magic_number); + int32 file_version = 0; + ReadType(strm, &file_version); + if (magic_number != kSTListMagicNumber) { + LOG(ERROR) << "ReadSTListHeader: Wrong file type: " << filename; + return false; + } + if (file_version != kSTListFileVersion) { + LOG(ERROR) << "ReadSTListHeader: Wrong file version: " << filename; + return false; + } + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: Error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTList(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STLIST_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/sttable.h b/projects/llm_framework/include/fst/extensions/far/sttable.h new file mode 100644 index 00000000..2a01bb16 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/sttable.h @@ -0,0 +1,353 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A generic string-to-type table file format. +// +// This is not meant as a generalization of SSTable. This is more of a simple +// replacement for SSTable in order to provide an open-source implementation +// of the FAR format for the external version of the FST library. + +#ifndef FST_EXTENSIONS_FAR_STTABLE_H_ +#define FST_EXTENSIONS_FAR_STTABLE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +static constexpr int32 kSTTableMagicNumber = 2125656924; +static constexpr int32 kSTTableFileVersion = 1; + +// String-type table writing class for an object of type T using a functor +// Writer. The Writer functor must provide at least the following interface: +// +// struct Writer { +// void operator()(std::ostream &, const T &) const; +// }; +template +class STTableWriter { + public: + explicit STTableWriter(const string &filename) + : stream_(filename, std::ios_base::out | std::ios_base::binary), + error_(false) { + WriteType(stream_, kSTTableMagicNumber); + WriteType(stream_, kSTTableFileVersion); + if (stream_.fail()) { + FSTERROR() << "STTableWriter::STTableWriter: Error writing to file: " + << filename; + error_ = true; + } + } + + static STTableWriter *Create(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableWriter: Writing to standard out unsupported."; + return nullptr; + } + return new STTableWriter(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STTableWriter::Add: Key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STTableWriter::Add: Key out of order: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + positions_.push_back(stream_.tellp()); + WriteType(stream_, key); + entry_writer_(stream_, t); + } + + bool Error() const { return error_; } + + ~STTableWriter() { + WriteType(stream_, positions_); + WriteType(stream_, static_cast(positions_.size())); + } + + private: + Writer entry_writer_; + std::ofstream stream_; + std::vector positions_; // Position in file of each key-entry pair. + string last_key_; // Last key. + bool error_; + + STTableWriter(const STTableWriter &) = delete; + STTableWriter &operator=(const STTableWriter &) = delete; +}; + +// String-type table reading class for object of type T using a functor Reader. +// Reader must provide at least the following interface: +// +// struct Reader { +// T *operator()(std::istream &) const; +// }; +// +template +class STTableReader { + public: + explicit STTableReader(const std::vector &filenames) + : sources_(filenames), error_(false) { + compare_.reset(new Compare(&keys_)); + keys_.resize(filenames.size()); + streams_.resize(filenames.size(), 0); + positions_.resize(filenames.size()); + for (size_t i = 0; i < filenames.size(); ++i) { + streams_[i] = new std::ifstream( + filenames[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader::STTableReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + int32 magic_number = 0; + ReadType(*streams_[i], &magic_number); + int32 file_version = 0; + ReadType(*streams_[i], &file_version); + if (magic_number != kSTTableMagicNumber) { + FSTERROR() << "STTableReader::STTableReader: Wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTTableFileVersion) { + FSTERROR() << "STTableReader::STTableReader: Wrong file version: " + << filenames[i]; + error_ = true; + return; + } + int64 num_entries; + streams_[i]->seekg(-static_cast(sizeof(int64)), std::ios_base::end); + ReadType(*streams_[i], &num_entries); + if (num_entries > 0) { + streams_[i]->seekg(-static_cast(sizeof(int64)) * (num_entries + 1), + std::ios_base::end); + positions_[i].resize(num_entries); + for (size_t j = 0; (j < num_entries) && (!streams_[i]->fail()); ++j) { + ReadType(*streams_[i], &(positions_[i][j])); + } + streams_[i]->seekg(positions_[i][0]); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader::STTableReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + } + MakeHeap(); + } + + ~STTableReader() { + for (auto &stream : streams_) delete stream; + } + + static STTableReader *Open(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableReader: Operation not supported on standard input"; + return nullptr; + } + std::vector filenames; + filenames.push_back(filename); + return new STTableReader(filenames); + } + + static STTableReader *Open(const std::vector &filenames) { + return new STTableReader(filenames); + } + + void Reset() { + if (error_) return; + for (size_t i = 0; i < streams_.size(); ++i) + streams_[i]->seekg(positions_[i].front()); + MakeHeap(); + } + + bool Find(const string &key) { + if (error_) return false; + for (size_t i = 0; i < streams_.size(); ++i) LowerBound(i, key); + MakeHeap(); + if (heap_.empty()) return false; + return keys_[current_] == key; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + if (streams_[current_]->tellg() <= positions_[current_].back()) { + ReadType(*(streams_[current_]), &(keys_[current_])); + if (streams_[current_]->fail()) { + FSTERROR() << "STTableReader: Error reading file: " + << sources_[current_]; + error_ = true; + return; + } + std::push_heap(heap_.begin(), heap_.end(), *compare_); + } else { + heap_.pop_back(); + } + if (!heap_.empty()) PopHeap(); + } + + const string &GetKey() const { return keys_[current_]; } + + const T *GetEntry() const { return entry_.get(); } + + bool Error() const { return error_; } + + private: + // Comparison functor used to compare stream IDs in the heap. + struct Compare { + explicit Compare(const std::vector *keys) : keys(keys) {} + + bool operator()(size_t i, size_t j) const { + return (*keys)[i] > (*keys)[j]; + }; + + private: + const std::vector *keys; + }; + + // Positions the stream at the position corresponding to the lower bound for + // the specified key. + void LowerBound(size_t id, const string &find_key) { + auto *strm = streams_[id]; + const auto &positions = positions_[id]; + if (positions.empty()) return; + size_t low = 0; + size_t high = positions.size() - 1; + while (low < high) { + size_t mid = (low + high) / 2; + strm->seekg(positions[mid]); + string key; + ReadType(*strm, &key); + if (key > find_key) { + high = mid; + } else if (key < find_key) { + low = mid + 1; + } else { + for (size_t i = mid; i > low; --i) { + strm->seekg(positions[i - 1]); + ReadType(*strm, &key); + if (key != find_key) { + strm->seekg(positions[i]); + return; + } + } + strm->seekg(positions[low]); + return; + } + } + strm->seekg(positions[low]); + } + + // Adds all streams to the heap. + void MakeHeap() { + heap_.clear(); + for (size_t i = 0; i < streams_.size(); ++i) { + if (positions_[i].empty()) continue; + ReadType(*streams_[i], &(keys_[i])); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader: Error reading file: " << sources_[i]; + error_ = true; + return; + } + heap_.push_back(i); + } + if (heap_.empty()) return; + std::make_heap(heap_.begin(), heap_.end(), *compare_); + PopHeap(); + } + + // Positions the stream with the lowest key at the top of the heap, sets + // current_ to the ID of that stream, and reads the current entry from that + // stream. + void PopHeap() { + std::pop_heap(heap_.begin(), heap_.end(), *compare_); + current_ = heap_.back(); + entry_.reset(entry_reader_(*streams_[current_])); + if (!entry_) error_ = true; + if (streams_[current_]->fail()) { + FSTERROR() << "STTableReader: Error reading entry for key: " + << keys_[current_] << ", file: " << sources_[current_]; + error_ = true; + } + } + + Reader entry_reader_; + std::vector streams_; // Input streams. + std::vector sources_; // Corresponding file names. + std::vector> positions_; // Index of positions. + std::vector keys_; // Lowest unread key for each stream. + std::vector heap_; // Heap containing ID of streams with unread keys. + int64 current_; // ID of current stream to be read. + std::unique_ptr compare_; // Functor comparing stream IDs. + mutable std::unique_ptr entry_; // The currently read entry. + bool error_; +}; + +// String-type table header reading function template on the entry header type. +// The Header type must provide at least the following interface: +// +// struct Header { +// void Read(std::istream &istrm, const string &filename); +// }; +template +bool ReadSTTableHeader(const string &filename, Header *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTTable: Can't read header from standard input"; + return false; + } + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: Could not open file: " << filename; + return false; + } + int32 magic_number = 0; + ReadType(strm, &magic_number); + int32 file_version = 0; + ReadType(strm, &file_version); + if (magic_number != kSTTableMagicNumber) { + LOG(ERROR) << "ReadSTTableHeader: Wrong file type: " << filename; + return false; + } + if (file_version != kSTTableFileVersion) { + LOG(ERROR) << "ReadSTTableHeader: Wrong file version: " << filename; + return false; + } + int64 i = -1; + strm.seekg(-static_cast(sizeof(int64)), std::ios_base::end); + ReadType(strm, &i); // Reads number of entries + if (strm.fail()) { + LOG(ERROR) << "ReadSTTableHeader: Error reading file: " << filename; + return false; + } + if (i == 0) return true; // No entry header to read. + strm.seekg(-2 * static_cast(sizeof(int64)), std::ios_base::end); + ReadType(strm, &i); // Reads position for last entry in file. + strm.seekg(i); + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (strm.fail()) { + LOG(ERROR) << "ReadSTTableHeader: Error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTTable(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STTABLE_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h new file mode 100644 index 00000000..a6ac7279 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h @@ -0,0 +1,1074 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ +#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace fst { + +// Forward declaration +template +class FeatureGroupBuilder; + +// For logging purposes +inline string TranslateLabel(int64 label, const SymbolTable *syms); +template +string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms); +template +string JoinLabels(const std::vector *Dump(); + + private: + bool error_; + CompactSet all_output_labels_; + std::map> word_output_map_, word_feat_map_; + std::map> feat_groups_; + std::vector>> groups_; + size_t max_future_size_; + Label max_input_label_; + const SymbolTable *isyms_, *fsyms_, *osyms_; + + LinearFstDataBuilder(const LinearFstDataBuilder &) = delete; + LinearFstDataBuilder &operator=(const LinearFstDataBuilder &) = delete; +}; + +// Builds a LinearFstData tailored for a LinearClassifierFst. The +// major difference between an ordinary LinearFstData that works on +// taggers and a LinearFstData that works on classifiers is that +// feature groups are divided into sections by the prediction class +// label. For a prediction label `pred` and a logical group id +// `group`, the actual group id is `group * num_classes + pred - +// 1`. +// +// This layout saves us from recording output labels in each single +// FeatureGroup. Because there is no need for any delaying, stripping +// the output allows features with different shapes but using the same +// set of feature label mapping to reside in a single FeatureGroup. +template +class LinearClassifierFstDataBuilder { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Constructs a builder for a `num_classes`-class classifier, + // optinally with associated symbol tables for diagnostic + // output. The output labels (i.e. prediction) must be in the range + // of [1, num_classes]. + explicit LinearClassifierFstDataBuilder(size_t num_classes, + const SymbolTable *isyms = nullptr, + const SymbolTable *fsyms = nullptr, + const SymbolTable *osyms = nullptr) + : error_(false), + num_classes_(num_classes), + num_groups_(0), + builder_(isyms, fsyms, osyms) {} + + // Tests whether the builder has encountered any error. Similar to + // LinearFstDataBuilder<>::Error(). + bool Error() const { return error_; } + + // Same as LinearFstDataBuilder<>::AddWord(). + bool AddWord(Label word, const std::vector *Dump(); + + private: + std::vector builder_; +}; + +// Builds a single feature group. Usually used in +// `LinearFstDataBuilder::AddWeight()`. See that method for the +// constraints on grouping features. +template +class FeatureGroupBuilder { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Constructs a builder with the given future size. All features + // added to the group will have look-ahead windows of this size. + FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms, + const SymbolTable *osyms) + : error_(false), future_size_(future_size), fsyms_(fsyms), osyms_(osyms) { + // This edge is special; see doc of class `FeatureGroup` on the + // details. + start_ = trie_.Insert(trie_.Root(), InputOutputLabel(kNoLabel, kNoLabel)); + } + + // Tests whether the builder has encountered any error. No operation + // is valid if the builder is already at error state. All other + // public methods should check this before any actual operations. + bool Error() const { return error_; } + + // Adds a feature weight with the given context. Returns true iff + // the weight is added. A weight is not added if it has ill-formed + // context involving start-, end-of-sentence marks. + // + // Note: `input` is the sequence of input + // features, instead of input labels themselves. `input` must be at + // least as long as `future_size`; `output` may be empty, but + // usually should be non-empty because an empty output context is + // useless in discriminative modelling. All labels in both `input` + // and `output` must be > 0 (this is checked in + // `LinearFstDataBuilder::AddWeight()`). See + // LinearFstDataBuilder<>::AddWeight for more details. + // + // This may fail if the input is smaller than the look-ahead window. + bool AddWeight(const std::vector *Dump(size_t max_future_size); + + private: + typedef typename FeatureGroup::InputOutputLabel InputOutputLabel; + typedef typename FeatureGroup::InputOutputLabelHash InputOutputLabelHash; + typedef typename FeatureGroup::WeightBackLink WeightBackLink; + // Nested trie topology uses more memory but we can traverse a + // node's children easily, which is required in `BuildBackLinks()`. + typedef NestedTrieTopology Topology; + typedef MutableTrie Trie; + + // Finds the first node with an arc with `label` following the + // back-off chain of `parent`. Returns the node index or + // `kNoTrieNodeId` when not found. The number of hops is stored in + // `hop` when it is not `nullptr`. + // + // This does not fail. + int FindFirstMatch(InputOutputLabel label, int parent, int *hop) const; + + // Links each node to its immediate back-off. root is linked to -1. + // + // This may fail when the unique immediate back-off constraint is + // violated. + void BuildBackLinks(); + + // Traces back on the back-chain for each node to multiply the + // weights from back-offs to the node itself. + // + // This does not fail. + void PreAccumulateWeights(); + + // Reconstruct the path from trie root to given node for logging. + bool TrieDfs(const Topology &topology, int cur, int target, + std::vector *path) const; + string TriePath(int node, const Topology &topology) const; + + bool error_; + size_t future_size_; + Trie trie_; + int start_; + const SymbolTable *fsyms_, *osyms_; + + FeatureGroupBuilder(const FeatureGroupBuilder &) = delete; + FeatureGroupBuilder &operator=(const FeatureGroupBuilder &) = delete; +}; + +// +// Implementation of methods in `LinearFstDataBuilder` +// +template +bool LinearFstDataBuilder::AddWord(Label word, + const std::vector::kStartOfSentence || + word == LinearFstData::kEndOfSentence) { + LOG(WARNING) << "Ignored: adding boundary label: " + << TranslateLabel(word, isyms_) + << "(start-of-sentence=" << LinearFstData::kStartOfSentence + << ", end-of-sentence=" << LinearFstData::kEndOfSentence + << ")"; + return false; + } + if (word <= 0) { + error_ = true; + FSTERROR() << "Word label must be > 0; got " << word; + return false; + } + if (word > max_input_label_) max_input_label_ = word; + // Make sure the word hasn't been added before + if (word_feat_map_.find(word) != word_feat_map_.end()) { + error_ = true; + FSTERROR() << "Input word " << TranslateLabel(word, isyms_) + << " is added twice"; + return false; + } + // Store features + std::set::AddWord( + Label word, const std::vector::kStartOfSentence || + output == LinearFstData::kEndOfSentence) { + LOG(WARNING) << "Ignored: word = " << TranslateLabel(word, isyms_) + << ": adding boundary label as possible output: " << output + << "(start-of-sentence=" + << LinearFstData::kStartOfSentence + << ", end-of-sentence=" << LinearFstData::kEndOfSentence + << ")"; + continue; + } + if (output <= 0) { + error_ = true; + FSTERROR() << "Output label must be > 0; got " << output; + return false; + } + outputs->insert(output); + all_output_labels_.Insert(output); + } + return true; +} + +template +inline int LinearFstDataBuilder::AddGroup(size_t future_size) { + if (error_) { + FSTERROR() << "Calling LinearFstDataBuilder<>::AddGroup() at error state"; + return -1; + } + size_t ret = groups_.size(); + groups_.emplace_back(new FeatureGroupBuilder(future_size, fsyms_, osyms_)); + if (future_size > max_future_size_) max_future_size_ = future_size; + return ret; +} + +template +bool LinearFstDataBuilder::AddWeight(size_t group, + const std::vector::kStartOfSentence && + input[i - 1] != LinearFstData::kStartOfSentence) + start_in_middle = true; + if (input[i - 1] == LinearFstData::kEndOfSentence && + input[i] != LinearFstData::kEndOfSentence) + end_in_middle = true; + } + if (start_in_middle) { + LOG(WARNING) << "Ignored: start-of-sentence in the middle of the input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (end_in_middle) { + LOG(WARNING) << "Ignored: end-of-sentence in the middle of the input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + } + // Check well-formedness of boundary marks on the output. + { + bool non_first_start = false, non_last_end = false; + for (int i = 1; i < output.size(); ++i) { + if (output[i] == LinearFstData::kStartOfSentence) + non_first_start = true; + if (output[i - 1] == LinearFstData::kEndOfSentence) + non_last_end = true; + } + if (non_first_start) { + LOG(WARNING) << "Ignored: start-of-sentence not appearing " + << "as the first label in the output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (non_last_end) { + LOG(WARNING) << "Ignored: end-of-sentence not appearing " + << "as the last label in the output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + } + + for (size_t i = 0; i < input.size(); ++i) { + Label feat = input[i]; + if (feat != LinearFstData::kStartOfSentence && + feat != LinearFstData::kEndOfSentence && feat <= 0) { + error_ = true; + FSTERROR() << "Feature label must be > 0; got " << feat; + return false; + } + feat_groups_[feat].insert(group); + } + for (size_t i = 0; i < output.size(); ++i) { + Label label = output[i]; + if (label != LinearFstData::kStartOfSentence && + label != LinearFstData::kEndOfSentence && label <= 0) { + error_ = true; + FSTERROR() << "Output label must be > 0; got " << label; + return false; + } + if (label != LinearFstData::kStartOfSentence && + label != LinearFstData::kEndOfSentence) + all_output_labels_.Insert(label); + } + + // Everything looks good at this point (more checks on the way in + // the feature group). Add this feature weight. + bool added = groups_[group]->AddWeight(input, output, weight); + if (groups_[group]->Error()) { + error_ = true; + FSTERROR() << "FeatureGroupBuilder<>::AddWeight() failed"; + return false; + } + return added; +} + +template +LinearFstData *LinearFstDataBuilder::Dump() { + if (error_) { + FSTERROR() << "Calling LinearFstDataBuilder<>::Dump() at error state"; + return nullptr; + } + + std::unique_ptr> data(new LinearFstData()); + data->max_future_size_ = max_future_size_; + data->max_input_label_ = max_input_label_; + + // Feature groups; free builders after it's dumped. + data->groups_.resize(groups_.size()); + for (int group = 0; group != groups_.size(); ++group) { + FeatureGroup *new_group = groups_[group]->Dump(max_future_size_); + if (new_group == nullptr) { + error_ = true; + FSTERROR() << "Error in dumping group " << group; + return nullptr; + } + data->groups_[group].reset(new_group); + groups_[group].reset(); + VLOG(1) << "Group " << group << ": " << new_group->Stats(); + } + + // Per-group feature mapping + data->group_feat_map_.Init(data->NumGroups(), max_input_label_ + 1); + for (Label word = 1; word <= max_input_label_; ++word) { + typename std::map>::const_iterator it = + word_feat_map_.find(word); + if (it == word_feat_map_.end()) continue; + for (typename std::set::AddWord( + Label word, const std::vector::AddGroup() { + if (error_) { + FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddGroup() at " + "error state"; + return -1; + } + for (int i = 0; i < num_classes_; ++i) builder_.AddGroup(0); + if (builder_.Error()) { + error_ = true; + return -1; + } + return num_groups_++; +} + +template +inline bool LinearClassifierFstDataBuilder::AddWeight( + size_t group, const std::vector *LinearClassifierFstDataBuilder::Dump() { + if (error_) { + FSTERROR() + << "Calling LinearClassifierFstDataBuilder<>::Dump() at error state"; + return nullptr; + } + LinearFstData *data = builder_.Dump(); + error_ = true; + return data; +} + +// +// Implementation of methods in `FeatureGroupBuilder` +// +template +bool FeatureGroupBuilder::AddWeight(const std::vector::kStartOfSentence) + ++num_input_start; + int num_output_start = 0; + while (num_output_start < output.size() && + output[num_output_start] == LinearFstData::kStartOfSentence) + ++num_output_start; + int num_input_end = 0; + for (int i = input.size() - 1; + i >= 0 && input[i] == LinearFstData::kEndOfSentence; --i) + ++num_input_end; + int num_output_end = 0; + for (int i = output.size() - 1; + i >= 0 && output[i] == LinearFstData::kEndOfSentence; --i) + ++num_output_end; + + DCHECK_LE(num_output_end, 1); + + if (input.size() - num_input_start < future_size_) { + LOG(WARNING) << "Ignored: start-of-sentence in the future!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, fsyms_); + return false; + } + if (num_input_start > 0 && + input.size() - future_size_ - num_input_start < + output.size() - num_output_start) { + LOG(WARNING) << "Ignored: matching start-of-sentence with actual output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (num_output_start > 0 && + input.size() - future_size_ - num_input_start > + output.size() - num_output_start) { + LOG(WARNING) << "Ignored: matching start-of-sentence with actual input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + // The following two require `num_output_end` <= 1. + if (num_input_end > future_size_ && num_input_end - future_size_ != 1) { + LOG(WARNING) << "Ignored: matching end-of-sentence with actual output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (num_output_end > 0 && + ((input.size() == future_size_ && future_size_ != num_input_end) || + (input.size() > future_size_ && + num_input_end != future_size_ + num_output_end))) { + LOG(WARNING) << "Ignored: matching end-of-sentence with actual input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + // Check if the context has no other labels than boundary marks + // (such features are useless). + if (num_input_start + num_input_end == input.size() && + num_output_start + num_output_end == output.size()) { + LOG(WARNING) + << "Ignored: feature context consisting of only boundary marks!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + + // Start point for insertion in the trie. Insert at `start_` iff the + // beginning of the context is non-consumed start-of-sentence. + int cur = (num_input_start == 0 && num_output_start <= future_size_) + ? trie_.Root() + : start_; + // Skip all input start-of-sentence marks + size_t ipos = num_input_start; + // Skip to keep at most `future_size_` start-of-sentence marks + size_t opos = + num_output_start <= future_size_ ? 0 : num_output_start - future_size_; + // Skip `num_output_end` end-of-sentence marks on both input and output + size_t iend = !input.empty() ? input.size() - num_output_end : 0, + oend = output.size() - num_output_end; + // Further, when output is empty, keep at most `future_size_` + // end-of-sentence marks on input. + if (output.empty() && num_input_end > future_size_) + iend = input.size() - num_input_end + future_size_; + + // Actual feature context is (input[ipos:iend], output[opos:oend]). + + // Pad `kNoLabel` as don't cares on the shorter of actual `input` + // and `output`. + const size_t effective_input_size = iend - ipos, + effective_output_size = oend - opos; + if (effective_input_size > effective_output_size) { + for (size_t pad = effective_input_size - effective_output_size; pad != 0; + --pad, ++ipos) + cur = trie_.Insert(cur, InputOutputLabel(input[ipos], kNoLabel)); + } else if (effective_input_size < effective_output_size) { + for (size_t pad = effective_output_size - effective_input_size; pad != 0; + --pad, ++opos) + cur = trie_.Insert(cur, InputOutputLabel(kNoLabel, output[opos])); + } + CHECK_EQ(iend - ipos, oend - opos); + for (; ipos != iend; ++ipos, ++opos) + cur = trie_.Insert(cur, InputOutputLabel(input[ipos], output[opos])); + // We only need to attach final weight when there is an output + // end-of-sentence. When there is only end-of-sentence on the input, + // they are all consumed as the end-of-sentence paddings from + // `LinearFstImpl<>::ShiftBuffer()`. `LinearFstImpl<>::Expand()` + // and `LinearFstImpl<>::MatchInput()` ensures no other + // transition takes place after consuming the padding. + if (num_output_end > 0 || (output.empty() && num_input_end > future_size_)) + trie_[cur].final_weight = Times(trie_[cur].final_weight, weight); + else + trie_[cur].weight = Times(trie_[cur].weight, weight); + + return true; +} + +template +FeatureGroup *FeatureGroupBuilder::Dump(size_t max_future_size) { + if (error_) { + FSTERROR() << "Calling FeatureGroupBuilder<>::PreAccumulateWeights() " + << "at error state"; + return nullptr; + } + + if (max_future_size < future_size_) { + error_ = true; + FSTERROR() << "max_future_size (= " << max_future_size + << ") is smaller the builder's future_size (= " << future_size_ + << ")"; + return nullptr; + } + + BuildBackLinks(); + if (error_) return nullptr; + PreAccumulateWeights(); // does not fail + + FeatureGroup *ret = + new FeatureGroup(max_future_size - future_size_, start_); + + // Walk around the trie to compute next states + ret->next_state_.resize(trie_.NumNodes()); + const Topology &topology = trie_.TrieTopology(); + for (int i = 0; i < topology.NumNodes(); ++i) { + int next = i; + while (next != topology.Root() && topology.ChildrenOf(next).empty() && + trie_[next].final_weight == + trie_[trie_[next].back_link].final_weight) + next = trie_[next].back_link; + ret->next_state_[i] = next; + } + + // Copy the trie + typename FeatureGroup::Trie store_trie(trie_); + ret->trie_.swap(store_trie); + + // Put the builder at error state to prevent repeated call of `Dump()`. + error_ = true; + return ret; +} + +template +int FeatureGroupBuilder::FindFirstMatch(InputOutputLabel label, int parent, + int *hop) const { + int hop_count = 0; + int ret = kNoTrieNodeId; + for (; parent >= 0; parent = trie_[parent].back_link, ++hop_count) { + int next = trie_.Find(parent, label); + if (next != kNoTrieNodeId) { + ret = next; + break; + } + } + if (hop != nullptr) *hop = hop_count; + return ret; +} + +template +void FeatureGroupBuilder::BuildBackLinks() { + // Breadth first search from the root. In the case where we only + // have the input label, the immedate back-off is simply the longest + // suffix of the current node that is also in the trie. For a node + // reached from its parent with label L, we can simply walk through + // the parent's back-off chain to find the first state with an arc + // of the same label L. The uniqueness is always + // guanranteed. However, in the case with both input and output + // labels, it is possible to back off by removing first labels from + // either side, which in general causes non-uniqueness. + + const Topology &topology = trie_.TrieTopology(); + std::queue q; // all enqueued or visited nodes have known links + + // Note: nodes have back link initialized to -1 in their + // constructor. + q.push(trie_.Root()); + while (!error_ && !q.empty()) { + int parent = q.front(); + q.pop(); + // Find links for every child + const typename Topology::NextMap &children = topology.ChildrenOf(parent); + for (typename Topology::NextMap::const_iterator eit = children.begin(); + eit != children.end(); ++eit) { + const std::pair &edge = *eit; + InputOutputLabel label = edge.first; + int child = edge.second; + if (label.input == kNoLabel || label.output == kNoLabel) { + // Label pairs from root to here all have one and only one + // `kNoLabel` on the same side; equivalent to the + // "longest-suffix" case. + trie_[child].back_link = + FindFirstMatch(label, trie_[parent].back_link, nullptr); + } else { + // Neither side is `kNoLabel` at this point, there are + // three possible ways to back-off: if the parent backs + // off to some context with only one side non-empty, the + // empty side may remain empty; or else an exact match of + // both sides is needed. Try to find all three possible + // backs and look for the closest one (in terms of hops + // along the parent's back-off chain). + int only_input_hop, only_output_hop, full_hop; + int only_input_link = + FindFirstMatch(InputOutputLabel(label.input, kNoLabel), parent, + &only_input_hop), + only_output_link = + FindFirstMatch(InputOutputLabel(kNoLabel, label.output), parent, + &only_output_hop), + full_link = + FindFirstMatch(label, trie_[parent].back_link, &full_hop); + if (only_input_link != -1 && only_output_link != -1) { + error_ = true; + FSTERROR() << "Branching back-off chain:\n" + << "\tnode " << child << ": " << TriePath(child, topology) + << "\n" + << "\tcan back-off to node " << only_input_link << ": " + << TriePath(only_input_link, topology) << "\n" + << "\tcan back-off to node " << only_output_link << ": " + << TriePath(only_output_link, topology); + return; + } else if (full_link != -1) { + ++full_hop; + if (full_hop <= only_input_hop && full_hop <= only_output_hop) { + trie_[child].back_link = full_link; + } else { + error_ = true; + int problem_link = only_input_link != kNoTrieNodeId + ? only_input_link + : only_output_link; + CHECK_NE(problem_link, kNoTrieNodeId); + FSTERROR() << "Branching back-off chain:\n" + << "\tnode " << child << ": " + << TriePath(child, topology) << "\n" + << "\tcan back-off to node " << full_link << ": " + << TriePath(full_link, topology) << "\n" + << "tcan back-off to node " << problem_link << ": " + << TriePath(problem_link, topology); + return; + } + } else { + trie_[child].back_link = + only_input_link != -1 ? only_input_link : only_output_link; + } + } + if (error_) break; + // Point to empty context (root) when no back-off can be found + if (trie_[child].back_link == -1) trie_[child].back_link = 0; + q.push(child); + } + } +} + +template +void FeatureGroupBuilder::PreAccumulateWeights() { + std::vector visited(trie_.NumNodes(), false); + visited[trie_.Root()] = true; + + for (size_t i = 0; i != trie_.NumNodes(); ++i) { + std::stack back_offs; + for (int j = i; !visited[j]; j = trie_[j].back_link) back_offs.push(j); + while (!back_offs.empty()) { + int j = back_offs.top(); + back_offs.pop(); + WeightBackLink &node = trie_[j]; + node.weight = Times(node.weight, trie_[node.back_link].weight); + node.final_weight = + Times(node.final_weight, trie_[node.back_link].final_weight); + visited[j] = true; + } + } +} + +template +bool FeatureGroupBuilder::TrieDfs( + const Topology &topology, int cur, int target, + std::vector *path) const { + if (cur == target) return true; + const typename Topology::NextMap &children = topology.ChildrenOf(cur); + for (typename Topology::NextMap::const_iterator eit = children.begin(); + eit != children.end(); ++eit) { + const std::pair &edge = *eit; + path->push_back(edge.first); + if (TrieDfs(topology, edge.second, target, path)) return true; + path->pop_back(); + } + return false; +} + +template +string FeatureGroupBuilder::TriePath(int node, + const Topology &topology) const { + std::vector labels; + TrieDfs(topology, topology.Root(), node, &labels); + bool first = true; + std::ostringstream strm; + for (typename std::vector::const_iterator it = + labels.begin(); + it != labels.end(); ++it) { + InputOutputLabel i = *it; + if (first) + first = false; + else + strm << ", "; + strm << "(" << TranslateLabel(i.input, fsyms_) << ", " + << TranslateLabel(i.output, osyms_) << ")"; + } + return strm.str(); +} + +inline string TranslateLabel(int64 label, const SymbolTable *syms) { + string ret; + if (syms != nullptr) ret += syms->Find(label); + if (ret.empty()) { + std::ostringstream strm; + strm << '<' << label << '>'; + ret = strm.str(); + } + return ret; +} + +template +string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms) { + if (begin == end) return ""; + std::ostringstream strm; + bool first = true; + for (Iterator it = begin; it != end; ++it) { + if (first) + first = false; + else + strm << '|'; + strm << TranslateLabel(*it, syms); + } + return strm.str(); +} + +template +string JoinLabels(const std::vector::kStartOfSentence; + } else if (left && !right) { + // Can only be end + (*sequence)[i] = LinearFstData::kEndOfSentence; + } else if (!left && right) { + // Can only be start + (*sequence)[i] = LinearFstData::kStartOfSentence; + } else { + // !left && !right; can't really tell + ++unresolved; + } + } + return unresolved; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h new file mode 100644 index 00000000..3b39c29d --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h @@ -0,0 +1,526 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Data structures for storing and looking up the actual feature weights. + +#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_ +#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace fst { + +// Forward declarations +template +class LinearFstDataBuilder; +template +class FeatureGroup; + +// Immutable data storage of the feature weights in a linear +// model. Produces state tuples that represent internal states of a +// LinearTaggerFst. Object of this class can only be constructed via +// either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()` +// and usually used as refcount'd object shared across mutiple +// `LinearTaggerFst` copies. +// +// TODO(wuke): more efficient trie implementation +template +class LinearFstData { + public: + friend class LinearFstDataBuilder; // For builder access + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Sentence boundary labels. Both of them are negative labels other + // than `kNoLabel`. + static const Label kStartOfSentence; + static const Label kEndOfSentence; + + // Constructs empty data; for non-trivial ways of construction see + // `Read()` and `LinearFstDataBuilder`. + LinearFstData() + : max_future_size_(0), max_input_label_(1), input_attribs_(1) {} + + // Appends the state tuple of the start state to `output`, where + // each tuple holds the node ids of a trie for each feature group. + void EncodeStartState(std::vector *Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + + private: + // Offsets in `output_pool_` + struct InputAttribute { + size_t output_begin, output_length; + + std::istream &Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + }; + + // Mapping from input label to per-group feature label + class GroupFeatureMap; + + // Translates the input label into input feature label of group + // `group`; returns `kNoLabel` when there is no feature for that + // group. + Label FindFeature(size_t group, Label word) const; + + size_t max_future_size_; + Label max_input_label_; + std::vector>> groups_; + std::vector input_attribs_; + std::vector::kStartOfSentence = -3; +template +const typename A::Label LinearFstData::kEndOfSentence = -2; + +template +template +void LinearFstData::TakeTransition(Iterator buffer_end, + Iterator trie_state_begin, + Iterator trie_state_end, Label ilabel, + Label olabel, std::vector::GroupTransition(int group_id, + int trie_state, + Label ilabel, Label olabel, + Weight *weight) const { + Label group_ilabel = FindFeature(group_id, ilabel); + return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight); +} + +template +template +inline typename A::Weight LinearFstData::FinalWeight( + Iterator trie_state_begin, Iterator trie_state_end) const { + DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size()); + size_t group_id = 0; + Weight accum = Weight::One(); + for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id) + accum = Times(accum, GroupFinalWeight(group_id, *it)); + return accum; +} + +template +inline std::pair::const_iterator, + typename std::vector::const_iterator> +LinearFstData::PossibleOutputLabels(Label word) const { + const InputAttribute &attrib = input_attribs_[word]; + if (attrib.output_length == 0) + return std::make_pair(output_set_.begin(), output_set_.end()); + else + return std::make_pair( + output_pool_.begin() + attrib.output_begin, + output_pool_.begin() + attrib.output_begin + attrib.output_length); +} + +template +inline LinearFstData *LinearFstData::Read(std::istream &strm) { // NOLINT + std::unique_ptr> data(new LinearFstData()); + ReadType(strm, &(data->max_future_size_)); + ReadType(strm, &(data->max_input_label_)); + // Feature groups + size_t num_groups = 0; + ReadType(strm, &num_groups); + data->groups_.resize(num_groups); + for (size_t i = 0; i < num_groups; ++i) + data->groups_[i].reset(FeatureGroup::Read(strm)); + // Other data + ReadType(strm, &(data->input_attribs_)); + ReadType(strm, &(data->output_pool_)); + ReadType(strm, &(data->output_set_)); + ReadType(strm, &(data->group_feat_map_)); + if (strm) { + return data.release(); + } else { + return nullptr; + } +} + +template +inline std::ostream &LinearFstData::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, max_future_size_); + WriteType(strm, max_input_label_); + // Feature groups + WriteType(strm, groups_.size()); + for (size_t i = 0; i < groups_.size(); ++i) { + groups_[i]->Write(strm); + } + // Other data + WriteType(strm, input_attribs_); + WriteType(strm, output_pool_); + WriteType(strm, output_set_); + WriteType(strm, group_feat_map_); + return strm; +} + +template +typename A::Label LinearFstData::FindFeature(size_t group, + Label word) const { + DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence); + if (word == kStartOfSentence || word == kEndOfSentence) + return word; + else + return group_feat_map_.Find(group, word); +} + +template +inline std::istream &LinearFstData::InputAttribute::Read( + std::istream &strm) { // NOLINT + ReadType(strm, &output_begin); + ReadType(strm, &output_length); + return strm; +} + +template +inline std::ostream &LinearFstData::InputAttribute::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, output_begin); + WriteType(strm, output_length); + return strm; +} + +// Forward declaration +template +class FeatureGroupBuilder; + +// An immutable grouping of features with similar context shape. Like +// `LinearFstData`, this can only be constructed via `Read()` or +// via its builder. +// +// Internally it uses a trie to store all feature n-grams and their +// weights. The label of a trie edge is a pair (feat, olabel) of +// labels. They can be either positive (ordinary label), `kNoLabel`, +// `kStartOfSentence`, or `kEndOfSentence`. `kNoLabel` usually means +// matching anything, with one exception: from the root of the trie, +// there is a special (kNoLabel, kNoLabel) that leads to the implicit +// start-of-sentence state. This edge is never actually matched +// (`FindFirstMatch()` ensures this). +template +class FeatureGroup { + public: + friend class FeatureGroupBuilder; // for builder access + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + int Start() const { return start_; } + + // Finds destination node from `cur` by consuming `ilabel` and + // `olabel`. The transition weight is multiplied onto `weight`. + int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const; + + // Returns the final weight of the current trie state. Only valid if + // the state is already known to be part of a final state (see + // `LinearFstData<>::CanBeFinal()`). + Weight FinalWeight(int trie_state) const { + return trie_[trie_state].final_weight; + } + + static FeatureGroup *Read(std::istream &strm) { // NOLINT + size_t delay; + ReadType(strm, &delay); + int start; + ReadType(strm, &start); + Trie trie; + ReadType(strm, &trie); + std::unique_ptr> ret(new FeatureGroup(delay, start)); + ret->trie_.swap(trie); + ReadType(strm, &ret->next_state_); + if (strm) { + return ret.release(); + } else { + return nullptr; + } + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, delay_); + WriteType(strm, start_); + WriteType(strm, trie_); + WriteType(strm, next_state_); + return strm; + } + + size_t Delay() const { return delay_; } + + string Stats() const; + + private: + // Label along the arcs on the trie. `kNoLabel` means anything + // (non-negative label) can match; both sides holding `kNoLabel` + // is not allow; otherwise the label is > 0 (enforced by + // `LinearFstDataBuilder::AddWeight()`). + struct InputOutputLabel; + struct InputOutputLabelHash; + + // Data to be stored on the trie + struct WeightBackLink { + int back_link; + Weight weight, final_weight; + + WeightBackLink() + : back_link(kNoTrieNodeId), + weight(Weight::One()), + final_weight(Weight::One()) {} + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &back_link); + ReadType(strm, &weight); + ReadType(strm, &final_weight); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, back_link); + WriteType(strm, weight); + WriteType(strm, final_weight); + return strm; + } + }; + + typedef FlatTrieTopology Topology; + typedef MutableTrie Trie; + + explicit FeatureGroup(size_t delay, int start) + : delay_(delay), start_(start) {} + + // Finds the first node with an arc with `label` following the + // back-off chain of `parent`. Returns the node index or + // `kNoTrieNodeId` when not found. + int FindFirstMatch(InputOutputLabel label, int parent) const; + + size_t delay_; + int start_; + Trie trie_; + // Where to go after hitting this state. When we reach a state with + // no child and with no additional final weight (i.e. its final + // weight is the same as its back-off), we can immediately go to its + // back-off state. + std::vector next_state_; + + FeatureGroup(const FeatureGroup &) = delete; + FeatureGroup &operator=(const FeatureGroup &) = delete; +}; + +template +struct FeatureGroup::InputOutputLabel { + Label input, output; + + InputOutputLabel(Label i = kNoLabel, Label o = kNoLabel) + : input(i), output(o) {} + + bool operator==(InputOutputLabel that) const { + return input == that.input && output == that.output; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &input); + ReadType(strm, &output); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, input); + WriteType(strm, output); + return strm; + } +}; + +template +struct FeatureGroup::InputOutputLabelHash { + size_t operator()(InputOutputLabel label) const { + return static_cast(label.input * 7853 + label.output); + } +}; + +template +int FeatureGroup::Walk(int cur, Label ilabel, Label olabel, + Weight *weight) const { + // Note: user of this method need to ensure `ilabel` and `olabel` + // are valid (e.g. see DCHECKs in + // `LinearFstData<>::TakeTransition()` and + // `LinearFstData<>::FindFeature()`). + int next; + if (ilabel == LinearFstData::kStartOfSentence) { + // An observed start-of-sentence only occurs in the beginning of + // the input, when this feature group is delayed (i.e. there is + // another feature group with a larger future size). The actual + // input hasn't arrived so stay at the start state. + DCHECK_EQ(cur, start_); + next = start_; + } else { + // First, try exact match + next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur); + // Then try with don't cares + if (next == kNoTrieNodeId) + next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur); + if (next == kNoTrieNodeId) + next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur); + // All failed, go to empty context + if (next == kNoTrieNodeId) next = trie_.Root(); + *weight = Times(*weight, trie_[next].weight); + next = next_state_[next]; + } + return next; +} + +template +inline int FeatureGroup::FindFirstMatch(InputOutputLabel label, + int parent) const { + if (label.input == kNoLabel && label.output == kNoLabel) + return kNoTrieNodeId; // very important; see class doc. + for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) { + int next = trie_.Find(parent, label); + if (next != kNoTrieNodeId) return next; + } + return kNoTrieNodeId; +} + +template +inline string FeatureGroup::Stats() const { + std::ostringstream strm; + int num_states = 2; + for (int i = 2; i < next_state_.size(); ++i) + num_states += i == next_state_[i]; + strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)"; + return strm.str(); +} + +template +class LinearFstData::GroupFeatureMap { + public: + GroupFeatureMap() {} + + void Init(size_t num_groups, size_t num_words) { + num_groups_ = num_groups; + pool_.clear(); + pool_.resize(num_groups * num_words, kNoLabel); + } + + Label Find(size_t group_id, Label ilabel) const { + return pool_[IndexOf(group_id, ilabel)]; + } + + bool Set(size_t group_id, Label ilabel, Label feat) { + size_t i = IndexOf(group_id, ilabel); + if (pool_[i] != kNoLabel && pool_[i] != feat) { + FSTERROR() << "Feature group " << group_id + << " already has feature for word " << ilabel; + return false; + } + pool_[i] = feat; + return true; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &num_groups_); + ReadType(strm, &pool_); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, num_groups_); + WriteType(strm, pool_); + return strm; + } + + private: + size_t IndexOf(size_t group_id, Label ilabel) const { + return ilabel * num_groups_ + group_id; + } + + size_t num_groups_; + // `pool_[ilabel * num_groups_ + group_id]` is the feature active + // for group `group_id` with input `ilabel` + std::vector { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename Collection::SetIterator NGramIterator; + + // Constructs an empty FST by default. + LinearTaggerFstImpl() + : CacheImpl(CacheOptions()), + data_(std::make_shared>()), + delay_(0) { + SetType("linear-tagger"); + } + + // Constructs the FST with given data storage and symbol + // tables. + // + // TODO(wuke): when there is no constraint on output we can delay + // less than `data->MaxFutureSize` positions. + LinearTaggerFstImpl(const LinearFstData *data, const SymbolTable *isyms, + const SymbolTable *osyms, CacheOptions opts) + : CacheImpl(opts), data_(data), delay_(data->MaxFutureSize()) { + SetType("linear-tagger"); + SetProperties(kILabelSorted, kFstProperties); + SetInputSymbols(isyms); + SetOutputSymbols(osyms); + ReserveStubSpace(); + } + + // Copy by sharing the underlying data storage. + LinearTaggerFstImpl(const LinearTaggerFstImpl &impl) + : CacheImpl(impl), data_(impl.data_), delay_(impl.delay_) { + SetType("linear-tagger"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + ReserveStubSpace(); + } + + StateId Start() { + if (!HasStart()) { + StateId start = FindStartState(); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + state_stub_.clear(); + FillState(s, &state_stub_); + if (CanBeFinal(state_stub_)) + SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_), + InternalEnd(state_stub_))); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new + // destination states as needed. + void Expand(StateId s); + + // Appends to `arcs` all out-going arcs from state `s` that matches `label` as + // the input label. + void MatchInput(StateId s, Label ilabel, std::vector *arcs); + + static LinearTaggerFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + bool Write(std::ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader header; + header.SetStart(kNoStateId); + WriteHeader(strm, opts, kFileVersion, &header); + data_->Write(strm); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + private: + static const int kMinFileVersion; + static const int kFileVersion; + + // A collection of functions to access parts of the state tuple. A + // state tuple is a vector of `Label`s with two parts: + // [buffer] [internal]. + // + // - [buffer] is a buffer of observed input labels with length + // `delay_`. `LinearFstData::kStartOfSentence` + // (resp. `LinearFstData::kEndOfSentence`) are used as + // paddings when the buffer has fewer than `delay_` elements, which + // can only appear as the prefix (resp. suffix) of the buffer. + // + // - [internal] is the internal state tuple for `LinearFstData` + typename std::vector::kStartOfSentence); + // Append internal states + data_->EncodeStartState(&state_stub_); + return FindState(state_stub_); + } + + // Tests whether the buffer in `(begin, end)` is empty. + bool IsEmptyBuffer(typename std::vector::kEndOfSentence => + // buffer[i+x] == LinearFstData::kEndOfSentence + // - buffer[i] == LinearFstData::kStartOfSentence => + // buffer[i-x] == LinearFstData::kStartOfSentence + return delay_ == 0 || *(end - 1) == LinearFstData::kStartOfSentence || + *begin == LinearFstData::kEndOfSentence; + } + + // Tests whether the given state tuple can be a final state. A state + // is final iff there is no observed input in the buffer. + bool CanBeFinal(const std::vector::kMinFileVersion = 1; + +template +const int LinearTaggerFstImpl::kFileVersion = 1; + +template +inline typename A::Label LinearTaggerFstImpl::ShiftBuffer( + const std::vector::kEndOfSentence); + if (delay_ == 0) { + DCHECK_GT(ilabel, 0); + return ilabel; + } else { + (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel; + return *BufferBegin(state); + } +} + +template +inline A LinearTaggerFstImpl::MakeArc(const std::vector::kEndOfSentence); + DCHECK(olabel > 0 || olabel == LinearFstData::kStartOfSentence); + Weight weight(Weight::One()); + data_->TakeTransition(BufferEnd(state), InternalBegin(state), + InternalEnd(state), ilabel, olabel, next_stub_, + &weight); + StateId nextstate = FindState(*next_stub_); + // Restore `next_stub_` to its size before the call + next_stub_->resize(delay_); + // In the actual arc, we use epsilons instead of boundaries. + return A(ilabel == LinearFstData::kEndOfSentence ? 0 : ilabel, + olabel == LinearFstData::kStartOfSentence ? 0 : olabel, weight, + nextstate); +} + +template +inline void LinearTaggerFstImpl::ExpandArcs(StateId s, + const std::vector::kStartOfSentence) { + // This happens when input is shorter than `delay_`. + PushArc(s, MakeArc(state, ilabel, LinearFstData::kStartOfSentence, + next_stub_)); + } else { + std::pair::const_iterator, + typename std::vector::const_iterator> range = + data_->PossibleOutputLabels(obs_ilabel); + for (typename std::vector::const_iterator it = + range.first; + it != range.second; ++it) + PushArc(s, MakeArc(state, ilabel, *it, next_stub_)); + } +} + +// TODO(wuke): this has much in duplicate with `ExpandArcs()` +template +inline void LinearTaggerFstImpl::AppendArcs(StateId /*s*/, + const std::vector::kStartOfSentence) { + // This happens when input is shorter than `delay_`. + arcs->push_back( + MakeArc(state, ilabel, LinearFstData::kStartOfSentence, next_stub_)); + } else { + std::pair::const_iterator, + typename std::vector::const_iterator> range = + data_->PossibleOutputLabels(obs_ilabel); + for (typename std::vector::const_iterator it = + range.first; + it != range.second; ++it) + arcs->push_back(MakeArc(state, ilabel, *it, next_stub_)); + } +} + +template +void LinearTaggerFstImpl::Expand(StateId s) { + VLOG(3) << "Expand " << s; + state_stub_.clear(); + FillState(s, &state_stub_); + + // Precompute the first `delay_ - 1` elements in the buffer of + // next states, which are identical for different input/output. + next_stub_.clear(); + next_stub_.resize(delay_); + if (delay_ > 0) + std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), + next_stub_.begin()); + + // Epsilon transition for flushing out the next observed input + if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) + ExpandArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_); + + // Non-epsilon input when we haven't flushed + if (delay_ == 0 || + *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) + for (Label ilabel = data_->MinInputLabel(); + ilabel <= data_->MaxInputLabel(); ++ilabel) + ExpandArcs(s, state_stub_, ilabel, &next_stub_); + + SetArcs(s); +} + +template +void LinearTaggerFstImpl::MatchInput(StateId s, Label ilabel, + std::vector *arcs) { + state_stub_.clear(); + FillState(s, &state_stub_); + + // Precompute the first `delay_ - 1` elements in the buffer of + // next states, which are identical for different input/output. + next_stub_.clear(); + next_stub_.resize(delay_); + if (delay_ > 0) + std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), + next_stub_.begin()); + + if (ilabel == 0) { + // Epsilon transition for flushing out the next observed input + if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) + AppendArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_, + arcs); + } else { + // Non-epsilon input when we haven't flushed + if (delay_ == 0 || + *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) + AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs); + } +} + +template +inline LinearTaggerFstImpl *LinearTaggerFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { // NOLINT + std::unique_ptr> impl(new LinearTaggerFstImpl()); + FstHeader header; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { + return nullptr; + } + impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); + if (!impl->data_) { + return nullptr; + } + impl->delay_ = impl->data_->MaxFutureSize(); + impl->ReserveStubSpace(); + return impl.release(); +} + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class LinearTaggerFst : public ImplToFst> { + public: + friend class ArcIterator>; + friend class StateIterator>; + friend class LinearFstMatcherTpl>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef DefaultCacheStore Store; + typedef typename Store::State State; + using Impl = internal::LinearTaggerFstImpl; + + LinearTaggerFst() : ImplToFst(std::make_shared()) {} + + explicit LinearTaggerFst(LinearFstData *data, + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + CacheOptions opts = CacheOptions()) + : ImplToFst(std::make_shared(data, isyms, osyms, opts)) {} + + explicit LinearTaggerFst(const Fst &fst) + : ImplToFst(std::make_shared()) { + LOG(FATAL) << "LinearTaggerFst: no constructor from arbitrary FST."; + } + + // See Fst<>::Copy() for doc. + LinearTaggerFst(const LinearTaggerFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this LinearTaggerFst. See Fst<>::Copy() for further doc. + LinearTaggerFst *Copy(bool safe = false) const override { + return new LinearTaggerFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new LinearFstMatcherTpl>(this, match_type); + } + + static LinearTaggerFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + static LinearTaggerFst *Read(std::istream &in, // NOLINT + const FstReadOptions &opts) { + auto *impl = Impl::Read(in, opts); + return impl ? new LinearTaggerFst(std::shared_ptr(impl)) : nullptr; + } + + bool Write(const string &filename) const override { + if (!filename.empty()) { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Write: Can't open file: " << filename; + return false; + } + return Write(strm, FstWriteOptions(filename)); + } else { + return Write(std::cout, FstWriteOptions("standard output")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + explicit LinearTaggerFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + void operator=(const LinearTaggerFst &fst) = delete; +}; + +// Specialization for LinearTaggerFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const LinearTaggerFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for LinearTaggerFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const LinearTaggerFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void LinearTaggerFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +namespace internal { + +// Implementation class for on-the-fly generated LinearClassifierFst with +// special optimization in matching. +template +class LinearClassifierFstImpl : public CacheImpl { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename Collection::SetIterator NGramIterator; + + // Constructs an empty FST by default. + LinearClassifierFstImpl() + : CacheImpl(CacheOptions()), + data_(std::make_shared>()) { + SetType("linear-classifier"); + num_classes_ = 0; + num_groups_ = 0; + } + + // Constructs the FST with given data storage, number of classes and + // symbol tables. + LinearClassifierFstImpl(const LinearFstData *data, size_t num_classes, + const SymbolTable *isyms, const SymbolTable *osyms, + CacheOptions opts) + : CacheImpl(opts), + data_(data), + num_classes_(num_classes), + num_groups_(data_->NumGroups() / num_classes_) { + SetType("linear-classifier"); + SetProperties(kILabelSorted, kFstProperties); + SetInputSymbols(isyms); + SetOutputSymbols(osyms); + ReserveStubSpace(); + } + + // Copy by sharing the underlying data storage. + LinearClassifierFstImpl(const LinearClassifierFstImpl &impl) + : CacheImpl(impl), + data_(impl.data_), + num_classes_(impl.num_classes_), + num_groups_(impl.num_groups_) { + SetType("linear-classifier"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + ReserveStubSpace(); + } + + StateId Start() { + if (!HasStart()) { + StateId start = FindStartState(); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + state_stub_.clear(); + FillState(s, &state_stub_); + SetFinal(s, FinalWeight(state_stub_)); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new + // destination states as needed. + void Expand(StateId s); + + // Appends to `arcs` all out-going arcs from state `s` that matches + // `label` as the input label. + void MatchInput(StateId s, Label ilabel, std::vector *arcs); + + static LinearClassifierFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader header; + header.SetStart(kNoStateId); + WriteHeader(strm, opts, kFileVersion, &header); + data_->Write(strm); + WriteType(strm, num_classes_); + if (!strm) { + LOG(ERROR) << "LinearClassifierFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + private: + static const int kMinFileVersion; + static const int kFileVersion; + + // A collection of functions to access parts of the state tuple. A + // state tuple is a vector of `Label`s with two parts: + // [prediction] [internal]. + // + // - [prediction] is a single label of the predicted class. A state + // must have a positive class label, unless it is the start state. + // + // - [internal] is the internal state tuple for `LinearFstData` of + // the given class; or kNoTrieNodeId's if in start state. + Label &Prediction(std::vector &) = delete; +}; + +template +const int LinearClassifierFstImpl::kMinFileVersion = 0; + +template +const int LinearClassifierFstImpl::kFileVersion = 0; + +template +void LinearClassifierFstImpl::Expand(StateId s) { + VLOG(3) << "Expand " << s; + state_stub_.clear(); + FillState(s, &state_stub_); + next_stub_.clear(); + next_stub_.resize(1 + num_groups_); + + if (IsStartState(state_stub_)) { + // Make prediction + for (Label pred = 1; pred <= num_classes_; ++pred) { + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); + PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_))); + } + } else { + Label pred = Prediction(state_stub_); + DCHECK_GT(pred, 0); + DCHECK_LE(pred, num_classes_); + for (Label ilabel = data_->MinInputLabel(); + ilabel <= data_->MaxInputLabel(); ++ilabel) { + Prediction(next_stub_) = pred; + Weight weight = Weight::One(); + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = + data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i), + ilabel, pred, &weight); + PushArc(s, A(ilabel, 0, weight, FindState(next_stub_))); + } + } + + SetArcs(s); +} + +template +void LinearClassifierFstImpl::MatchInput(StateId s, Label ilabel, + std::vector *arcs) { + state_stub_.clear(); + FillState(s, &state_stub_); + next_stub_.clear(); + next_stub_.resize(1 + num_groups_); + + if (IsStartState(state_stub_)) { + // Make prediction if `ilabel` is epsilon. + if (ilabel == 0) { + for (Label pred = 1; pred <= num_classes_; ++pred) { + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); + arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_))); + } + } + } else if (ilabel != 0) { + Label pred = Prediction(state_stub_); + Weight weight = Weight::One(); + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupTransition( + GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight); + arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_))); + } +} + +template +inline LinearClassifierFstImpl *LinearClassifierFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { + std::unique_ptr> impl( + new LinearClassifierFstImpl()); + FstHeader header; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { + return nullptr; + } + impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); + if (!impl->data_) { + return nullptr; + } + ReadType(strm, &impl->num_classes_); + if (!strm) { + return nullptr; + } + impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_; + if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) { + FSTERROR() << "Total number of feature groups is not a multiple of the " + "number of classes: num groups = " + << impl->data_->NumGroups() + << ", num classes = " << impl->num_classes_; + return nullptr; + } + impl->ReserveStubSpace(); + return impl.release(); +} + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class LinearClassifierFst + : public ImplToFst> { + public: + friend class ArcIterator>; + friend class StateIterator>; + friend class LinearFstMatcherTpl>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef DefaultCacheStore Store; + typedef typename Store::State State; + using Impl = internal::LinearClassifierFstImpl; + + LinearClassifierFst() : ImplToFst(std::make_shared()) {} + + explicit LinearClassifierFst(LinearFstData *data, size_t num_classes, + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + CacheOptions opts = CacheOptions()) + : ImplToFst( + std::make_shared(data, num_classes, isyms, osyms, opts)) {} + + explicit LinearClassifierFst(const Fst &fst) + : ImplToFst(std::make_shared()) { + LOG(FATAL) << "LinearClassifierFst: no constructor from arbitrary FST."; + } + + // See Fst<>::Copy() for doc. + LinearClassifierFst(const LinearClassifierFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this LinearClassifierFst. See Fst<>::Copy() for further doc. + LinearClassifierFst *Copy(bool safe = false) const override { + return new LinearClassifierFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new LinearFstMatcherTpl>(this, match_type); + } + + static LinearClassifierFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearClassifierFst::Read: Can't open file: " + << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + static LinearClassifierFst *Read(std::istream &in, + const FstReadOptions &opts) { + auto *impl = Impl::Read(in, opts); + return impl ? new LinearClassifierFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(const string &filename) const override { + if (!filename.empty()) { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ProdLmFst::Write: Can't open file: " << filename; + return false; + } + return Write(strm, FstWriteOptions(filename)); + } else { + return Write(std::cout, FstWriteOptions("standard output")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + explicit LinearClassifierFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + void operator=(const LinearClassifierFst &fst) = delete; +}; + +// Specialization for LinearClassifierFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const LinearClassifierFst &fst) + : CacheStateIterator>(fst, + fst.GetMutableImpl()) {} +}; + +// Specialization for LinearClassifierFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const LinearClassifierFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void LinearClassifierFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Specialized Matcher for LinearFsts. This matcher only supports +// matching from the input side. This is intentional because comparing +// the scores of different input sequences with the same output +// sequence is meaningless in a discriminative model. +template +class LinearFstMatcherTpl : public MatcherBase { + public: + typedef typename F::Arc Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + typedef F FST; + + // This makes a copy of the FST. + LinearFstMatcherTpl(const FST &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + match_type_(match_type), + s_(kNoStateId), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + cur_arc_(0), + error_(false) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_OUTPUT: + case MATCH_NONE: + break; + default: + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This doesn't copy the FST. + LinearFstMatcherTpl(const FST *fst, MatchType match_type) + : fst_(*fst), + match_type_(match_type), + s_(kNoStateId), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + cur_arc_(0), + error_(false) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_OUTPUT: + case MATCH_NONE: + break; + default: + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + LinearFstMatcherTpl(const LinearFstMatcherTpl &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + match_type_(matcher.match_type_), + s_(kNoStateId), + current_loop_(false), + loop_(matcher.loop_), + cur_arc_(0), + error_(matcher.error_) {} + + LinearFstMatcherTpl *Copy(bool safe = false) const override { + return new LinearFstMatcherTpl(*this, safe); + } + + MatchType Type(bool /*test*/) const override { + // `MATCH_INPUT` is the only valid type + return match_type_ == MATCH_INPUT ? match_type_ : MATCH_NONE; + } + + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + // `MATCH_INPUT` is the only valid type + if (match_type_ != MATCH_INPUT) { + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + error_ = true; + } + loop_.nextstate = s; + } + + bool Find(Label label) final { + if (error_) { + current_loop_ = false; + return false; + } + current_loop_ = label == 0; + if (label == kNoLabel) label = 0; + arcs_.clear(); + cur_arc_ = 0; + fst_.GetMutableImpl()->MatchInput(s_, label, &arcs_); + return current_loop_ || !arcs_.empty(); + } + + bool Done() const final { + return !(current_loop_ || cur_arc_ < arcs_.size()); + } + + const Arc &Value() const final { + return current_loop_ ? loop_ : arcs_[cur_arc_]; + } + + void Next() final { + if (current_loop_) + current_loop_ = false; + else + ++cur_arc_; + } + + ssize_t Priority(StateId s) final { return kRequirePriority; } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { + if (error_) props |= kError; + return props; + } + + uint32 Flags() const override { return kRequireMatch; } + + private: + std::unique_ptr owned_fst_; + const FST &fst_; + MatchType match_type_; // Type of match to perform. + StateId s_; // Current state. + bool current_loop_; // Current arc is the implicit loop. + Arc loop_; // For non-consuming symbols. + // All out-going arcs matching the label in last Find() call. + std::vector arcs_; + size_t cur_arc_; // Index to the arc that `Value()` should return. + bool error_; // Error encountered. +}; + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linearscript.h b/projects/llm_framework/include/fst/extensions/linear/linearscript.h new file mode 100644 index 00000000..54106d20 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linearscript.h @@ -0,0 +1,391 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ +#define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +DECLARE_string(delimiter); +DECLARE_string(empty_symbol); +DECLARE_string(start_symbol); +DECLARE_string(end_symbol); +DECLARE_bool(classifier); + +namespace fst { +namespace script { +typedef std::tuple + LinearCompileArgs; + +bool ValidateDelimiter(); +bool ValidateEmptySymbol(); + +// Returns the proper label given the symbol. For symbols other than +// `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol +// table to decide the label. Depending on whether +// `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it +// either returns `kNoLabel` for later processing or decides the label +// right away. +template +inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) { + if (str == FLAGS_start_symbol) + return str == FLAGS_end_symbol ? kNoLabel + : LinearFstData::kStartOfSentence; + else if (str == FLAGS_end_symbol) + return LinearFstData::kEndOfSentence; + else + return syms->AddSymbol(str); +} + +// Splits `str` with `delim` as the delimiter and stores the labels in +// `output`. +template +void SplitAndPush(const string &str, const char delim, SymbolTable *syms, + std::vector *output) { + if (str == FLAGS_empty_symbol) return; + std::istringstream strm(str); + string buf; + while (std::getline(strm, buf, delim)) + output->push_back(LookUp(buf, syms)); +} + +// Like `std::replace_copy` but returns the number of modifications +template +size_t ReplaceCopy(InputIterator first, InputIterator last, + OutputIterator result, const T &old_value, + const T &new_value) { + size_t changes = 0; + while (first != last) { + if (*first == old_value) { + *result = new_value; + ++changes; + } else { + *result = *first; + } + ++first; + ++result; + } + return changes; +} + +template +bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT + SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, + typename Arc::Label *word, + std::vector *feature_labels, + std::vector *possible_labels, + size_t *num_line); + +template +bool GetModelRecord(const string &model, std::istream &strm, // NOLINT + SymbolTable *fsyms, SymbolTable *osyms, + std::vector *input_labels, + std::vector *output_labels, + typename Arc::Weight *weight, size_t *num_line); + +// Reads in vocabulary file. Each line is in the following format +// +// word features [ possible output ] +// +// where features and possible output are `FLAGS_delimiter`-delimited lists of +// tokens +template +void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms, + SymbolTable *osyms, LinearFstDataBuilder *builder) { + std::ifstream in(vocab); + if (!in) LOG(FATAL) << "Can't open file: " << vocab; + size_t num_line = 0, num_added = 0; + std::vector fields; + std::vector feature_labels, possible_labels; + typename Arc::Label word; + while (GetVocabRecord(vocab, in, isyms, fsyms, osyms, &word, + &feature_labels, &possible_labels, &num_line)) { + if (word == kNoLabel) { + LOG(WARNING) << "Ignored: boundary word: " << fields[0]; + continue; + } + if (possible_labels.empty()) + num_added += builder->AddWord(word, feature_labels); + else + num_added += builder->AddWord(word, feature_labels, possible_labels); + } + VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from " + << vocab; +} + +template +void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms, + SymbolTable *osyms, + LinearClassifierFstDataBuilder *builder) { + std::ifstream in(vocab); + if (!in) LOG(FATAL) << "Can't open file: " << vocab; + size_t num_line = 0, num_added = 0; + std::vector fields; + std::vector feature_labels, possible_labels; + typename Arc::Label word; + while (GetVocabRecord(vocab, in, isyms, fsyms, osyms, &word, + &feature_labels, &possible_labels, &num_line)) { + if (!possible_labels.empty()) + LOG(FATAL) + << "Classifier vocabulary should not have possible output constraint"; + if (word == kNoLabel) { + LOG(WARNING) << "Ignored: boundary word: " << fields[0]; + continue; + } + num_added += builder->AddWord(word, feature_labels); + } + VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from " + << vocab; +} + +// Reads in model file. The first line is an integer designating the +// size of future window in the input sequences. After this, each line +// is in the following format +// +// input sequence output sequence weight +// +// input sequence is a `FLAGS_delimiter`-delimited sequence of feature +// labels (see `AddVocab()`) . output sequence is a +// `FLAGS_delimiter`-delimited sequence of output labels where the +// last label is the output of the feature position before the history +// boundary. +template +void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms, + LinearFstDataBuilder *builder) { + std::ifstream in(model); + if (!in) LOG(FATAL) << "Can't open file: " << model; + string line; + std::getline(in, line); + if (!in) LOG(FATAL) << "Empty file: " << model; + size_t future_size; + { + std::istringstream strm(line); + strm >> future_size; + if (!strm) LOG(FATAL) << "Can't read future size: " << model; + } + size_t num_line = 1, num_added = 0; + const int group = builder->AddGroup(future_size); + VLOG(1) << "Group " << group << ": from " << model << "; future size is " + << future_size << "."; + // Add the rest of lines as a single feature group + std::vector fields; + std::vector input_labels, output_labels; + typename Arc::Weight weight; + while (GetModelRecord(model, in, fsyms, osyms, &input_labels, + &output_labels, &weight, &num_line)) { + if (output_labels.empty()) + LOG(FATAL) << "Empty output sequence in source " << model << ", line " + << num_line; + + const typename Arc::Label marks[] = {LinearFstData::kStartOfSentence, + LinearFstData::kEndOfSentence}; + + std::vector copy_input(input_labels.size()), + copy_output(output_labels.size()); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + size_t num_input_changes = + ReplaceCopy(input_labels.begin(), input_labels.end(), + copy_input.begin(), kNoLabel, marks[i]); + size_t num_output_changes = + ReplaceCopy(output_labels.begin(), output_labels.end(), + copy_output.begin(), kNoLabel, marks[j]); + if ((num_input_changes > 0 || i == 0) && + (num_output_changes > 0 || j == 0)) + num_added += + builder->AddWeight(group, copy_input, copy_output, weight); + } + } + } + VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in " + << num_line << " lines."; +} + +template +void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms, + LinearClassifierFstDataBuilder *builder) { + std::ifstream in(model); + if (!in) LOG(FATAL) << "Can't open file: " << model; + string line; + std::getline(in, line); + if (!in) LOG(FATAL) << "Empty file: " << model; + size_t future_size; + { + std::istringstream strm(line); + strm >> future_size; + if (!strm) LOG(FATAL) << "Can't read future size: " << model; + } + if (future_size != 0) + LOG(FATAL) << "Classifier model must have future size = 0; got " + << future_size << " from " << model; + size_t num_line = 1, num_added = 0; + const int group = builder->AddGroup(); + VLOG(1) << "Group " << group << ": from " << model << "; future size is " + << future_size << "."; + // Add the rest of lines as a single feature group + std::vector fields; + std::vector input_labels, output_labels; + typename Arc::Weight weight; + while (GetModelRecord(model, in, fsyms, osyms, &input_labels, + &output_labels, &weight, &num_line)) { + if (output_labels.size() != 1) + LOG(FATAL) << "Output not a single label in source " << model << ", line " + << num_line; + + const typename Arc::Label marks[] = {LinearFstData::kStartOfSentence, + LinearFstData::kEndOfSentence}; + + typename Arc::Label pred = output_labels[0]; + + std::vector copy_input(input_labels.size()); + for (int i = 0; i < 2; ++i) { + size_t num_input_changes = + ReplaceCopy(input_labels.begin(), input_labels.end(), + copy_input.begin(), kNoLabel, marks[i]); + if (num_input_changes > 0 || i == 0) + num_added += builder->AddWeight(group, copy_input, pred, weight); + } + } + VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in " + << num_line << " lines."; +} + +void SplitByWhitespace(const string &str, std::vector *out); +int ScanNumClasses(char **models, int models_length); + +template +void LinearCompileTpl(LinearCompileArgs *args) { + const string &epsilon_symbol = std::get<0>(*args); + const string &unknown_symbol = std::get<1>(*args); + const string &vocab = std::get<2>(*args); + char **models = std::get<3>(*args); + const int models_length = std::get<4>(*args); + const string &out = std::get<5>(*args); + const string &save_isymbols = std::get<6>(*args); + const string &save_fsymbols = std::get<7>(*args); + const string &save_osymbols = std::get<8>(*args); + + SymbolTable isyms, // input (e.g. word tokens) + osyms, // output (e.g. tags) + fsyms; // feature (e.g. word identity, suffix, etc.) + isyms.AddSymbol(epsilon_symbol); + osyms.AddSymbol(epsilon_symbol); + fsyms.AddSymbol(epsilon_symbol); + isyms.AddSymbol(unknown_symbol); + + VLOG(1) << "start-of-sentence label is " + << LinearFstData::kStartOfSentence; + VLOG(1) << "end-of-sentence label is " << LinearFstData::kEndOfSentence; + + if (FLAGS_classifier) { + int num_classes = ScanNumClasses(models, models_length); + LinearClassifierFstDataBuilder builder(num_classes, &isyms, &fsyms, + &osyms); + + AddVocab(vocab, &isyms, &fsyms, &osyms, &builder); + for (int i = 0; i < models_length; ++i) + AddModel(models[i], &fsyms, &osyms, &builder); + + LinearClassifierFst fst(builder.Dump(), num_classes, &isyms, &osyms); + fst.Write(out); + } else { + LinearFstDataBuilder builder(&isyms, &fsyms, &osyms); + + AddVocab(vocab, &isyms, &fsyms, &osyms, &builder); + for (int i = 0; i < models_length; ++i) + AddModel(models[i], &fsyms, &osyms, &builder); + + LinearTaggerFst fst(builder.Dump(), &isyms, &osyms); + fst.Write(out); + } + + if (!save_isymbols.empty()) isyms.WriteText(save_isymbols); + if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols); + if (!save_osymbols.empty()) osyms.WriteText(save_osymbols); +} + +void LinearCompile(const string &arc_type, const string &epsilon_symbol, + const string &unknown_symbol, const string &vocab, + char **models, int models_len, const string &out, + const string &save_isymbols, const string &save_fsymbols, + const string &save_osymbols); + +template +bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT + SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, + typename Arc::Label *word, + std::vector *feature_labels, + std::vector *possible_labels, + size_t *num_line) { + string line; + if (!std::getline(strm, line)) return false; + ++(*num_line); + + std::vector fields; + SplitByWhitespace(line, &fields); + if (fields.size() != 3) + LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line " + << num_line; + + feature_labels->clear(); + possible_labels->clear(); + + *word = LookUp(fields[0], isyms); + + const char delim = FLAGS_delimiter[0]; + SplitAndPush(fields[1], delim, fsyms, feature_labels); + SplitAndPush(fields[2], delim, osyms, possible_labels); + + return true; +} + +template +bool GetModelRecord(const string &model, std::istream &strm, // NOLINT + SymbolTable *fsyms, SymbolTable *osyms, + std::vector *input_labels, + std::vector *output_labels, + typename Arc::Weight *weight, size_t *num_line) { + string line; + if (!std::getline(strm, line)) return false; + ++(*num_line); + + std::vector fields; + SplitByWhitespace(line, &fields); + if (fields.size() != 3) + LOG(FATAL) << "Wrong number of fields in source " << model << ", line " + << num_line; + + input_labels->clear(); + output_labels->clear(); + + const char delim = FLAGS_delimiter[0]; + SplitAndPush(fields[0], delim, fsyms, input_labels); + SplitAndPush(fields[1], delim, osyms, output_labels); + + *weight = StrToWeight(fields[2], model, *num_line); + + GuessStartOrEnd(input_labels, kNoLabel); + GuessStartOrEnd(output_labels, kNoLabel); + + return true; +} +} // namespace script +} // namespace fst + +#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \ + REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs); + +#endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h b/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h new file mode 100644 index 00000000..1b5d2eaf --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h @@ -0,0 +1,77 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ +#define FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Applies a FST model as a discriminative model to weighted input +// `ifst`. `A` is an arc type with tropical weight of all the +// input/output FSTs. +// +// In general, consider `ifst` an unnormalized probability +// distribution between its input X and output Y, P(X, Y); and `lfst` +// a group of unnormalized probability distributions of all its output +// Z for every input Y, Q(Z|Y). `normalize` controls whether Q is +// normalized for every Y before chaining with P(X, Y). I.e., for a +// path (X, Y, Z) in `ofst` (where Y is hidden), +// +// - When `normalize` is true, its weight is P(X, Y) Q(Z|Y) / sum_z Q(z|Y); +// - When `normalize` is false, its weight is P(X, Y) Q(Z|Y). +template +void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, + bool normalize = true) { + LogLinearApply(ifst, lfst, ofst, normalize); +} + +// This version gives finer control over the arc type (`B`) to be used +// in normalization. `B` is an arc type with log weight (e.g. `LogArc` +// or `Log64Arc`). +template +void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, + bool normalize = true) { + if (normalize) { + VectorFst unnormalized_ofst, rescored_ifsa; + Compose(ifst, lfst, &unnormalized_ofst); + { + VectorFst tropical_ifsa(unnormalized_ofst); + Project(&tropical_ifsa, PROJECT_INPUT); + { + VectorFst minimal_log_ifsa; + { + VectorFst log_ifsa; + ArcMap(tropical_ifsa, &log_ifsa, WeightConvertMapper()); + RmEpsilon(&log_ifsa); + Determinize(log_ifsa, &minimal_log_ifsa); + } + Minimize(&minimal_log_ifsa); + ArcMap(&minimal_log_ifsa, InvertWeightMapper()); + ArcMap(minimal_log_ifsa, &tropical_ifsa, WeightConvertMapper()); + } + ArcSort(&tropical_ifsa, OLabelCompare()); + Compose(tropical_ifsa, ifst, &rescored_ifsa); + } + ArcSort(&rescored_ifsa, OLabelCompare()); + Compose(rescored_ifsa, unnormalized_ofst, ofst); + } else { + Compose(ifst, lfst, ofst); + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/trie.h b/projects/llm_framework/include/fst/extensions/linear/trie.h new file mode 100644 index 00000000..b5ddb3ad --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/trie.h @@ -0,0 +1,444 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_TRIE_H_ +#define FST_EXTENSIONS_LINEAR_TRIE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +const int kNoTrieNodeId = -1; + +// Forward declarations of all available trie topologies. +template +class NestedTrieTopology; +template +class FlatTrieTopology; + +// A pair of parent node id and label, part of a trie edge +template +struct ParentLabel { + int parent; + L label; + + ParentLabel() {} + ParentLabel(int p, L l) : parent(p), label(l) {} + + bool operator==(const ParentLabel &that) const { + return parent == that.parent && label == that.label; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &parent); + ReadType(strm, &label); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, parent); + WriteType(strm, label); + return strm; + } +}; + +template +struct ParentLabelHash { + size_t operator()(const ParentLabel &pl) const { + return static_cast(pl.parent * 7853 + H()(pl.label)); + } +}; + +// The trie topology in a nested tree of hash maps; allows efficient +// iteration over children of a specific node. +template +class NestedTrieTopology { + public: + typedef L Label; + typedef H Hash; + typedef std::unordered_map NextMap; + + class const_iterator { + public: + typedef std::forward_iterator_tag iterator_category; + typedef std::pair, int> value_type; + typedef std::ptrdiff_t difference_type; + typedef const value_type *pointer; + typedef const value_type &reference; + + friend class NestedTrieTopology; + + const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {} + + reference operator*() { + UpdateStub(); + return stub_; + } + pointer operator->() { + UpdateStub(); + return &stub_; + } + + const_iterator &operator++(); + const_iterator &operator++(int); // NOLINT + + bool operator==(const const_iterator &that) const { + return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ && + cur_edge_ == that.cur_edge_; + } + bool operator!=(const const_iterator &that) const { + return !(*this == that); + } + + private: + const_iterator(const NestedTrieTopology *ptr, int cur_node) + : ptr_(ptr), cur_node_(cur_node) { + SetProperCurEdge(); + } + + void SetProperCurEdge() { + if (cur_node_ < ptr_->NumNodes()) + cur_edge_ = ptr_->nodes_[cur_node_]->begin(); + else + cur_edge_ = ptr_->nodes_[0]->begin(); + } + + void UpdateStub() { + stub_.first = ParentLabel(cur_node_, cur_edge_->first); + stub_.second = cur_edge_->second; + } + + const NestedTrieTopology *ptr_; + int cur_node_; + typename NextMap::const_iterator cur_edge_; + value_type stub_; + }; + + NestedTrieTopology(); + NestedTrieTopology(const NestedTrieTopology &that); + ~NestedTrieTopology(); + void swap(NestedTrieTopology &that); + NestedTrieTopology &operator=(const NestedTrieTopology &that); + bool operator==(const NestedTrieTopology &that) const; + bool operator!=(const NestedTrieTopology &that) const; + + int Root() const { return 0; } + size_t NumNodes() const { return nodes_.size(); } + int Insert(int parent, const L &label); + int Find(int parent, const L &label) const; + const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; } + + std::istream &Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + + const_iterator begin() const { return const_iterator(this, 0); } + const_iterator end() const { return const_iterator(this, NumNodes()); } + + private: + std::vector + nodes_; // Use pointers to avoid copying the maps when the + // vector grows +}; + +template +NestedTrieTopology::NestedTrieTopology() { + nodes_.push_back(new NextMap); +} + +template +NestedTrieTopology::NestedTrieTopology(const NestedTrieTopology &that) { + nodes_.reserve(that.nodes_.size()); + for (size_t i = 0; i < that.nodes_.size(); ++i) { + NextMap *node = that.nodes_[i]; + nodes_.push_back(new NextMap(*node)); + } +} + +template +NestedTrieTopology::~NestedTrieTopology() { + for (size_t i = 0; i < nodes_.size(); ++i) { + NextMap *node = nodes_[i]; + delete node; + } +} + +// TODO(wuke): std::swap compatibility +template +inline void NestedTrieTopology::swap(NestedTrieTopology &that) { + nodes_.swap(that.nodes_); +} + +template +inline NestedTrieTopology &NestedTrieTopology::operator=( + const NestedTrieTopology &that) { + NestedTrieTopology copy(that); + swap(copy); + return *this; +} + +template +inline bool NestedTrieTopology::operator==( + const NestedTrieTopology &that) const { + if (NumNodes() != that.NumNodes()) return false; + for (int i = 0; i < NumNodes(); ++i) + if (ChildrenOf(i) != that.ChildrenOf(i)) return false; + return true; +} + +template +inline bool NestedTrieTopology::operator!=( + const NestedTrieTopology &that) const { + return !(*this == that); +} + +template +inline int NestedTrieTopology::Insert(int parent, const L &label) { + int ret = Find(parent, label); + if (ret == kNoTrieNodeId) { + ret = NumNodes(); + (*nodes_[parent])[label] = ret; + nodes_.push_back(new NextMap); + } + return ret; +} + +template +inline int NestedTrieTopology::Find(int parent, const L &label) const { + typename NextMap::const_iterator it = nodes_[parent]->find(label); + return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second; +} + +template +inline std::istream &NestedTrieTopology::Read( + std::istream &strm) { // NOLINT + NestedTrieTopology new_trie; + size_t num_nodes; + if (!ReadType(strm, &num_nodes)) return strm; + for (size_t i = 1; i < num_nodes; ++i) new_trie.nodes_.push_back(new NextMap); + for (size_t i = 0; i < num_nodes; ++i) ReadType(strm, new_trie.nodes_[i]); + if (strm) swap(new_trie); + return strm; +} + +template +inline std::ostream &NestedTrieTopology::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, NumNodes()); + for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]); + return strm; +} + +template +inline typename NestedTrieTopology::const_iterator + &NestedTrieTopology::const_iterator::operator++() { + ++cur_edge_; + if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) { + ++cur_node_; + while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty()) + ++cur_node_; + SetProperCurEdge(); + } + return *this; +} + +template +inline typename NestedTrieTopology::const_iterator + &NestedTrieTopology::const_iterator::operator++(int) { // NOLINT + const_iterator save(*this); + ++(*this); + return save; +} + +// The trie topology in a single hash map; only allows iteration over +// all the edges in arbitrary order. +template +class FlatTrieTopology { + private: + typedef std::unordered_map, int, ParentLabelHash> + NextMap; + + public: + // Iterator over edges as std::pair, int> + typedef typename NextMap::const_iterator const_iterator; + typedef L Label; + typedef H Hash; + + FlatTrieTopology() {} + FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {} + template + explicit FlatTrieTopology(const T &that); + + // TODO(wuke): std::swap compatibility + void swap(FlatTrieTopology &that) { next_.swap(that.next_); } + + bool operator==(const FlatTrieTopology &that) const { + return next_ == that.next_; + } + bool operator!=(const FlatTrieTopology &that) const { + return !(*this == that); + } + + int Root() const { return 0; } + size_t NumNodes() const { return next_.size() + 1; } + int Insert(int parent, const L &label); + int Find(int parent, const L &label) const; + + std::istream &Read(std::istream &strm) { // NOLINT + return ReadType(strm, &next_); + } + std::ostream &Write(std::ostream &strm) const { // NOLINT + return WriteType(strm, next_); + } + + const_iterator begin() const { return next_.begin(); } + const_iterator end() const { return next_.end(); } + + private: + NextMap next_; +}; + +template +template +FlatTrieTopology::FlatTrieTopology(const T &that) + : next_(that.begin(), that.end()) {} + +template +inline int FlatTrieTopology::Insert(int parent, const L &label) { + int ret = Find(parent, label); + if (ret == kNoTrieNodeId) { + ret = NumNodes(); + next_[ParentLabel(parent, label)] = ret; + } + return ret; +} + +template +inline int FlatTrieTopology::Find(int parent, const L &label) const { + typename NextMap::const_iterator it = + next_.find(ParentLabel(parent, label)); + return it == next_.end() ? kNoTrieNodeId : it->second; +} + +// A collection of implementations of the trie data structure. The key +// is a sequence of type `L` which must be hashable. The value is of +// `V` which must be default constructible and copyable. In addition, +// a value object is stored for each node in the trie therefore +// copying `V` should be cheap. +// +// One can access the store values with an integer node id, using the +// [] operator. A valid node id can be obtained by the following ways: +// +// 1. Using the `Root()` method to get the node id of the root. +// +// 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense +// so every integer in this range is a valid node id. +// +// 3. Using the node id returned from a successful `Insert()` or +// `Find()` call. +// +// 4. Iterating over the trie edges with an `EdgeIterator` and using +// the node ids returned from its `Parent()` and `Child()` methods. +// +// Below is an example of inserting keys into the trie: +// +// const string words[] = {"hello", "health", "jello"}; +// Trie dict; +// for (auto word : words) { +// int cur = dict.Root(); +// for (char c : word) { +// cur = dict.Insert(cur, c); +// } +// dict[cur] = true; +// } +// +// And the following is an example of looking up the longest prefix of +// a string using the trie constructed above: +// +// string query = "healed"; +// size_t prefix_length = 0; +// int cur = dict.Find(dict.Root(), query[prefix_length]); +// while (prefix_length < query.size() && +// cur != Trie::kNoNodeId) { +// ++prefix_length; +// cur = dict.Find(cur, query[prefix_length]); +// } +template +class MutableTrie { + public: + template + friend class MutableTrie; + + typedef L Label; + typedef V Value; + typedef T Topology; + + // Constructs a trie with only the root node. + MutableTrie() {} + + // Conversion from another trie of a possiblly different + // topology. The underlying topology must supported conversion. + template + explicit MutableTrie(const MutableTrie &that) + : topology_(that.topology_), values_(that.values_) {} + + // TODO(wuke): std::swap compatibility + void swap(MutableTrie &that) { + topology_.swap(that.topology_); + values_.swap(that.values_); + } + + int Root() const { return topology_.Root(); } + size_t NumNodes() const { return topology_.NumNodes(); } + + // Inserts an edge with given `label` at node `parent`. Returns the + // child node id. If the node already exists, returns the node id + // right away. + int Insert(int parent, const L &label) { + int ret = topology_.Insert(parent, label); + values_.resize(NumNodes()); + return ret; + } + + // Finds the node id of the node from `parent` via `label`. Returns + // `kNoTrieNodeId` when such a node does not exist. + int Find(int parent, const L &label) const { + return topology_.Find(parent, label); + } + + const T &TrieTopology() const { return topology_; } + + // Accesses the value stored for the given node. + V &operator[](int node_id) { return values_[node_id]; } + const V &operator[](int node_id) const { return values_[node_id]; } + + // Comparison by content + bool operator==(const MutableTrie &that) const { + return topology_ == that.topology_ && values_ == that.values_; + } + + bool operator!=(const MutableTrie &that) const { return !(*this == that); } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &topology_); + ReadType(strm, &values_); + return strm; + } + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, topology_); + WriteType(strm, values_); + return strm; + } + + private: + T topology_; + std::vector values_; +}; + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_TRIE_H_ diff --git a/projects/llm_framework/include/fst/extensions/mpdt/compose.h b/projects/llm_framework/include/fst/extensions/mpdt/compose.h new file mode 100644 index 00000000..47714e37 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/mpdt/compose.h @@ -0,0 +1,267 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Compose an MPDT and an FST. + +#ifndef FST_EXTENSIONS_MPDT_COMPOSE_H_ +#define FST_EXTENSIONS_MPDT_COMPOSE_H_ + +#include + +#include +#include +#include + +namespace fst { + +template +class MPdtParenFilter { + public: + using FST1 = typename Filter::FST1; + using FST2 = typename Filter::FST2; + using Arc = typename Filter::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + + using StackId = StateId; + using ParenStack = internal::MPdtStack; + using FilterState1 = typename Filter::FilterState; + using FilterState2 = IntegerFilterState; + using FilterState = PairFilterState; + + MPdtParenFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr, + const std::vector> *parens = nullptr, + const std::vector(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + const ParenStack &GetStack() const { return GetImpl()->GetStack(); } + + const PdtStateTable &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + void operator=(const MPdtExpandFst &) = delete; +}; + +// Specialization for MPdtExpandFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const MPdtExpandFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for MPdtExpandFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const MPdtExpandFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s); + } +}; + +template +inline void MPdtExpandFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +struct MPdtExpandOptions { + bool connect; + bool keep_parentheses; + + explicit MPdtExpandOptions(bool connect = true, bool keep_parentheses = false) + : connect(connect), keep_parentheses(keep_parentheses) {} +}; + +// Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST. +// This version writes the expanded PDT to a mutable FST. In the MPDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// an MPDT, the parens for each stack must balance on a path. The open-close +// parenthesis label pair sets are passed using the parens argument, and the +// assignment of those pairs to stacks is passed using the assignments argument. +// The expansion enforces the parenthesis constraints. The MPDT must be +// expandable as an FST. +template +void Expand(const Fst &ifst, + const std::vector< + std::pair> &parens, + const std::vector &assignments, + MutableFst *ofst, const MPdtExpandOptions &opts) { + MPdtExpandFstOptions eopts; + eopts.gc_limit = 0; + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = MPdtExpandFst(ifst, parens, assignments, eopts); + if (opts.connect) Connect(ofst); +} + +// Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST. +// This version writes the expanded PDT to a mutable FST. In the MPDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// an MPDT, the parens for each stack must balance on a path. The open-close +// parenthesis label pair sets are passed using the parens argument, and the +// assignment of those pairs to stacks is passed using the assignments argument. +// The expansion enforces the parenthesis constraints. The MPDT must be +// expandable as an FST. +template +void Expand(const Fst &ifst, + const std::vector> &parens, + const std::vector &assignments, + MutableFst *ofst, bool connect = true, + bool keep_parentheses = false) { + const MPdtExpandOptions opts(connect, keep_parentheses); + Expand(ifst, parens, assignments, ofst, opts); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_MPDT_EXPAND_H_ diff --git a/projects/llm_framework/include/fst/extensions/mpdt/info.h b/projects/llm_framework/include/fst/extensions/mpdt/info.h new file mode 100644 index 00000000..512fcfa8 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/mpdt/info.h @@ -0,0 +1,190 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Prints information about an MPDT. + +#ifndef FST_EXTENSIONS_MPDT_INFO_H_ +#define FST_EXTENSIONS_MPDT_INFO_H_ + +#include +#include + +#include +#include + +namespace fst { + +// Compute various information about MPDTs, helper class for mpdtinfo.cc. +template +class MPdtInfo { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MPdtInfo(const Fst &fst, + const std::vector> &parens, + const std::vector { + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::WriteHeader; + + friend class ArcIterator>; + friend class NGramFstMatcher; + + public: + using FstImpl::InputSymbols; + using FstImpl::SetProperties; + using FstImpl::Properties; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstImpl() { + SetType("ngram"); + SetInputSymbols(nullptr); + SetOutputSymbols(nullptr); + SetProperties(kStaticProperties); + } + + NGramFstImpl(const Fst &fst, std::vector *order_out); + + explicit NGramFstImpl(const Fst &fst) : NGramFstImpl(fst, nullptr) {} + + NGramFstImpl(const NGramFstImpl &other) { + FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false."; + SetProperties(kError, kError); + } + + ~NGramFstImpl() override { + if (owned_) { + delete[] data_; + } + } + + static NGramFstImpl *Read(std::istream &strm, // NOLINT + const FstReadOptions &opts) { + NGramFstImpl *impl = new NGramFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; + uint64 num_states, num_futures, num_final; + const size_t offset = + sizeof(num_states) + sizeof(num_futures) + sizeof(num_final); + // Peek at num_states and num_futures to see how much more needs to be read. + strm.read(reinterpret_cast(&num_states), sizeof(num_states)); + strm.read(reinterpret_cast(&num_futures), sizeof(num_futures)); + strm.read(reinterpret_cast(&num_final), sizeof(num_final)); + size_t size = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast(data_region->mutable_data()); + // Copy num_states, num_futures and num_final back into data. + memcpy(data, reinterpret_cast(&num_states), sizeof(num_states)); + memcpy(data + sizeof(num_states), reinterpret_cast(&num_futures), + sizeof(num_futures)); + memcpy(data + sizeof(num_states) + sizeof(num_futures), + reinterpret_cast(&num_final), sizeof(num_final)); + strm.read(data + offset, size - offset); + if (strm.fail()) { + delete impl; + return nullptr; + } + impl->Init(data, false, data_region); + return impl; + } + + bool Write(std::ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(num_states_); + WriteHeader(strm, opts, kFileVersion, &hdr); + strm.write(data_, StorageSize()); + return !strm.fail(); + } + + StateId Start() const { return start_; } + + Weight Final(StateId state) const { + if (final_index_.Get(state)) { + return final_probs_[final_index_.Rank1(state)]; + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId state, NGramFstInst *inst = nullptr) const { + if (inst == nullptr) { + const std::pair zeros = + (state == 0) ? select_root_ : future_index_.Select0s(state); + return zeros.second - zeros.first - 1; + } + SetInstFuture(state, inst); + return inst->num_futures_ + ((state == 0) ? 0 : 1); + } + + size_t NumInputEpsilons(StateId state) const { + // State 0 has no parent, thus no backoff. + if (state == 0) return 0; + return 1; + } + + size_t NumOutputEpsilons(StateId state) const { + return NumInputEpsilons(state); + } + + StateId NumStates() const { return num_states_; } + + void InitStateIterator(StateIteratorData *data) const { + data->base = 0; + data->nstates = num_states_; + } + + static size_t Storage(uint64 num_states, uint64 num_futures, + uint64 num_final) { + uint64 b64; + Weight weight; + Label label; + size_t offset = + sizeof(num_states) + sizeof(num_futures) + sizeof(num_final); + offset += + sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) + + BitmapIndex::StorageSize(num_futures + num_states + 1) + + BitmapIndex::StorageSize(num_states)); + offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); + // Pad for alignemnt, see + // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + + (num_futures + 1) * sizeof(weight); + return offset; + } + + void SetInstFuture(StateId state, NGramFstInst *inst) const { + if (inst->state_ != state) { + inst->state_ = state; + const std::pair zeros = future_index_.Select0s(state); + inst->num_futures_ = zeros.second - zeros.first - 1; + inst->offset_ = future_index_.Rank1(zeros.first + 1); + } + } + + void SetInstNode(NGramFstInst *inst) const { + if (inst->node_state_ != inst->state_) { + inst->node_state_ = inst->state_; + inst->node_ = context_index_.Select1(inst->state_); + } + } + + void SetInstContext(NGramFstInst *inst) const { + SetInstNode(inst); + if (inst->context_state_ != inst->state_) { + inst->context_state_ = inst->state_; + inst->context_.clear(); + size_t node = inst->node_; + while (node != 0) { + inst->context_.push_back(context_words_[context_index_.Rank1(node)]); + node = context_index_.Select1(context_index_.Rank0(node) - 1); + } + } + } + + // Access to the underlying representation + const char *GetData(size_t *data_size) const { + *data_size = StorageSize(); + return data_; + } + + void Init(const char *data, bool owned, MappedFile *file = nullptr); + + const std::vector *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + size_t StorageSize() const { + return Storage(num_states_, num_futures_, num_final_); + } + + void GetStates(const std::vector::GetStates( + const std::vector; + + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef internal::NGramFstImpl Impl; + + explicit NGramFst(const Fst &dst) + : ImplToExpandedFst(std::make_shared(dst, nullptr)) {} + + NGramFst(const Fst &fst, std::vector *order_out) + : ImplToExpandedFst(std::make_shared(fst, order_out)) {} + + // Because the NGramFstImpl is a const stateless data structure, there + // is never a need to do anything beside copy the reference. + NGramFst(const NGramFst &fst, bool safe = false) + : ImplToExpandedFst(fst, false) {} + + NGramFst() : ImplToExpandedFst(std::make_shared()) {} + + // Non-standard constructor to initialize NGramFst directly from data. + NGramFst(const char *data, bool owned) + : ImplToExpandedFst(std::make_shared()) { + GetMutableImpl()->Init(data, owned, nullptr); + } + + // Get method that gets the data associated with Init(). + const char *GetData(size_t *data_size) const { + return GetImpl()->GetData(data_size); + } + + const std::vector *Copy(bool safe = false) const override { + return new NGramFst(*this, safe); + } + + static NGramFst *Read(std::istream &strm, const FstReadOptions &opts) { + Impl *impl = Impl::Read(strm, opts); + return impl ? new NGramFst(std::shared_ptr(impl)) : nullptr; + } + + static NGramFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm.good()) { + LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + inline void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + inline void InitArcIterator(StateId s, + ArcIteratorData *data) const override; + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new NGramFstMatcher(this, match_type); + } + + size_t StorageSize() const { return GetImpl()->StorageSize(); } + + static bool HasRequiredProps(const Fst &fst) { + static const auto props = + kAcceptor | kIDeterministic | kILabelSorted | kIEpsilons | kAccessible; + return fst.Properties(props, true) == props; + } + + static bool HasRequiredStructure(const Fst &fst) { + if (!HasRequiredProps(fst)) { + return false; + } + typename A::StateId unigram = fst.Start(); + while (true) { // Follows epsilon arc chain to find unigram state. + if (unigram == fst::kNoStateId) return false; // No unigram state. + typename fst::ArcIterator> aiter(fst, unigram); + if (aiter.Done() || aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + aiter.Next(); + } + // Other requirement: all states other than unigram an epsilon arc. + for (fst::StateIterator> siter(fst); !siter.Done(); + siter.Next()) { + const typename A::StateId &state = siter.Value(); + fst::ArcIterator> aiter(fst, state); + if (state != unigram) { + if (aiter.Done()) return false; + if (aiter.Value().ilabel != 0) return false; + aiter.Next(); + if (!aiter.Done() && aiter.Value().ilabel == 0) return false; + } + } + return true; + } + + private: + using ImplToExpandedFst>::GetImpl; + using ImplToExpandedFst>::GetMutableImpl; + + explicit NGramFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + mutable NGramFstInst inst_; +}; + +template +inline void NGramFst::InitArcIterator(StateId s, + ArcIteratorData *data) const { + GetImpl()->SetInstFuture(s, &inst_); + GetImpl()->SetInstNode(&inst_); + data->base = new ArcIterator>(*this, s); +} + +namespace internal { + +template +NGramFstImpl::NGramFstImpl(const Fst &fst, + std::vector *order_out) { + typedef A Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + SetType("ngram"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + SetProperties(kStaticProperties); + + // Check basic requirements for an OpenGrm language model Fst. + if (!NGramFst::HasRequiredProps(fst)) { + FSTERROR() << "NGramFst only accepts OpenGrm language models as input"; + SetProperties(kError, kError); + return; + } + + int64 num_states = CountStates(fst); + Label *context = new Label[num_states]; + + // Find the unigram state by starting from the start state, following + // epsilons. + StateId unigram = fst.Start(); + while (1) { + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state"; + SetProperties(kError, kError); + return; + } + ArcIterator> aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } + if (aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + } + + // Each state's context is determined by the subtree it is under from the + // unigram state. + std::queue> label_queue; + std::vector visited(num_states); + // Force an epsilon link to the start state. + label_queue.push(std::make_pair(fst.Start(), 0)); + for (ArcIterator> aiter(fst, unigram); !aiter.Done(); aiter.Next()) { + label_queue.push( + std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); + } + // investigate states in breadth first fashion to assign context words. + while (!label_queue.empty()) { + std::pair &now = label_queue.front(); + if (!visited[now.first]) { + context[now.first] = now.second; + visited[now.first] = true; + for (ArcIterator> aiter(fst, now.first); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + label_queue.push(std::make_pair(arc.nextstate, now.second)); + } + } + } + label_queue.pop(); + } + visited.clear(); + + // The arc from the start state should be assigned an epsilon to put it + // in front of the all other labels (which makes Start state 1 after + // unigram which is state 0). + context[fst.Start()] = 0; + + // Build the tree of contexts fst by reversing the epsilon arcs from fst. + VectorFst context_fst; + uint64 num_final = 0; + for (int i = 0; i < num_states; ++i) { + if (fst.Final(i) != Weight::Zero()) { + ++num_final; + } + context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); + } + context_fst.SetStart(unigram); + context_fst.SetInputSymbols(fst.InputSymbols()); + context_fst.SetOutputSymbols(fst.OutputSymbols()); + int64 num_context_arcs = 0; + int64 num_futures = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + const StateId &state = siter.Value(); + num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); + ArcIterator> aiter(fst, state); + if (!aiter.Done()) { + const Arc &arc = aiter.Value(); + // this arc goes from state to arc.nextstate, so create an arc from + // arc.nextstate to state to reverse it. + if (arc.ilabel == 0) { + context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], + arc.weight, state)); + num_context_arcs++; + } + } + } + if (num_context_arcs != context_fst.NumStates() - 1) { + FSTERROR() << "Number of contexts arcs != number of states - 1"; + SetProperties(kError, kError); + return; + } + if (context_fst.NumStates() != num_states) { + FSTERROR() << "Number of contexts != number of states"; + SetProperties(kError, kError); + return; + } + int64 context_props = + context_fst.Properties(kIDeterministic | kILabelSorted, true); + if (!(context_props & kIDeterministic)) { + FSTERROR() << "Input Fst is not structured properly"; + SetProperties(kError, kError); + return; + } + if (!(context_props & kILabelSorted)) { + ArcSort(&context_fst, ILabelCompare()); + } + + delete[] context; + + uint64 b64; + Weight weight; + Label label = kNoLabel; + const size_t storage = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast(data_region->mutable_data()); + memset(data, 0, storage); + size_t offset = 0; + memcpy(data + offset, reinterpret_cast(&num_states), + sizeof(num_states)); + offset += sizeof(num_states); + memcpy(data + offset, reinterpret_cast(&num_futures), + sizeof(num_futures)); + offset += sizeof(num_futures); + memcpy(data + offset, reinterpret_cast(&num_final), + sizeof(num_final)); + offset += sizeof(num_final); + uint64 *context_bits = reinterpret_cast(data + offset); + offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); + uint64 *future_bits = reinterpret_cast(data + offset); + offset += + BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); + uint64 *final_bits = reinterpret_cast(data + offset); + offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); + Label *context_words = reinterpret_cast::Init(const char *data, bool owned, + MappedFile *data_region) { + if (owned_) { + delete[] data_; + } + data_region_.reset(data_region); + owned_ = owned; + data_ = data; + size_t offset = 0; + num_states_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_states_); + num_futures_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_futures_); + num_final_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_final_); + uint64 bits; + size_t context_bits = num_states_ * 2 + 1; + size_t future_bits = num_futures_ + num_states_ + 1; + context_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); + future_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); + final_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); + context_words_ = reinterpret_cast(data_ + offset); + offset += (num_states_ + 1) * sizeof(*context_words_); + future_words_ = reinterpret_cast(data_ + offset); + offset += num_futures_ * sizeof(*future_words_); + offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); + backoff_ = reinterpret_cast(data_ + offset); + offset += (num_states_ + 1) * sizeof(*backoff_); + final_probs_ = reinterpret_cast(data_ + offset); + offset += num_final_ * sizeof(*final_probs_); + future_probs_ = reinterpret_cast(data_ + offset); + + context_index_.BuildIndex(context_, context_bits); + future_index_.BuildIndex(future_, future_bits); + final_index_.BuildIndex(final_, num_states_); + + select_root_ = context_index_.Select0s(0); + if (context_index_.Rank1(0) != 0 || select_root_.first != 1 || + context_index_.Get(2) == false) { + FSTERROR() << "Malformed file"; + SetProperties(kError, kError); + return; + } + root_children_ = context_words_ + context_index_.Rank1(2); + start_ = 1; +} + +template +inline typename A::StateId NGramFstImpl::Transition( + const std::vector { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // This makes a copy of the FST. + NGramFstMatcher(const NGramFst &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + inst_(fst_.inst_), + match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + // This doesn't copy the FST. + NGramFstMatcher(const NGramFst *fst, MatchType match_type) + : fst_(*fst), + inst_(fst_.inst_), + match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + // This makes a copy of the FST. + NGramFstMatcher(const NGramFstMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + inst_(matcher.inst_), + match_type_(matcher.match_type_), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + NGramFstMatcher *Copy(bool safe = false) const override { + return new NGramFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return match_type_; } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { return props; } + + void SetState(StateId s) final { + fst_.GetImpl()->SetInstFuture(s, &inst_); + current_loop_ = false; + } + + bool Find(Label label) final { + const Label nolabel = kNoLabel; + done_ = true; + if (label == 0 || label == nolabel) { + if (label == 0) { + current_loop_ = true; + loop_.nextstate = inst_.state_; + } + // The unigram state has no epsilon arc. + if (inst_.state_ != 0) { + arc_.ilabel = arc_.olabel = 0; + fst_.GetImpl()->SetInstNode(&inst_); + arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( + fst_.GetImpl()->context_index_.Select1( + fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); + arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; + done_ = false; + } + } else { + current_loop_ = false; + const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; + const Label *end = start + inst_.num_futures_; + const Label *search = std::lower_bound(start, end, label); + if (search != end && *search == label) { + size_t state = search - start; + arc_.ilabel = arc_.olabel = label; + arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; + fst_.GetImpl()->SetInstContext(&inst_); + arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); + done_ = false; + } + } + return !Done(); + } + + bool Done() const final { return !current_loop_ && done_; } + + const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + done_ = true; + } + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + std::unique_ptr> owned_fst_; + const NGramFst &fst_; + NGramFstInst inst_; + MatchType match_type_; // Supplied by caller + bool done_; + Arc arc_; + bool current_loop_; // Current arc is the implicit loop + Arc loop_; +}; + +/*****************************************************************************/ +// Specialization for NGramFst; see generic version in fst.h +// for sample usage (but use the ProdLmFst type!). This version +// should inline. +template +class StateIterator> : public StateIteratorBase { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const NGramFst &fst) + : s_(0), num_states_(fst.NumStates()) {} + + bool Done() const final { return s_ >= num_states_; } + + StateId Value() const final { return s_; } + + void Next() final { ++s_; } + + void Reset() final { s_ = 0; } + + private: + StateId s_; + StateId num_states_; +}; + +/*****************************************************************************/ +template +class ArcIterator> : public ArcIteratorBase { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ArcIterator(const NGramFst &fst, StateId state) + : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { + inst_ = fst.inst_; + impl_->SetInstFuture(state, &inst_); + impl_->SetInstNode(&inst_); + } + + bool Done() const final { + return i_ >= + ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1); + } + + const Arc &Value() const final { + bool eps = (inst_.node_ != 0 && i_ == 0); + StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; + if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { + arc_.ilabel = arc_.olabel = + eps ? 0 : impl_->future_words_[inst_.offset_ + state]; + lazy_ &= ~(kArcILabelValue | kArcOLabelValue); + } + if (flags_ & lazy_ & kArcNextStateValue) { + if (eps) { + arc_.nextstate = + impl_->context_index_.Rank1(impl_->context_index_.Select1( + impl_->context_index_.Rank0(inst_.node_) - 1)); + } else { + if (lazy_ & kArcNextStateValue) { + impl_->SetInstContext(&inst_); // first time only. + } + arc_.nextstate = impl_->Transition( + inst_.context_, impl_->future_words_[inst_.offset_ + state]); + } + lazy_ &= ~kArcNextStateValue; + } + if (flags_ & lazy_ & kArcWeightValue) { + arc_.weight = eps ? impl_->backoff_[inst_.state_] + : impl_->future_probs_[inst_.offset_ + state]; + lazy_ &= ~kArcWeightValue; + } + return arc_; + } + + void Next() final { + ++i_; + lazy_ = ~0; + } + + size_t Position() const final { return i_; } + + void Reset() final { + i_ = 0; + lazy_ = ~0; + } + + void Seek(size_t a) final { + if (i_ != a) { + i_ = a; + lazy_ = ~0; + } + } + + uint32 Flags() const final { return flags_; } + + void SetFlags(uint32 flags, uint32 mask) final { + flags_ &= ~mask; + flags_ |= (flags & kArcValueFlags); + } + + private: + mutable Arc arc_; + mutable uint32 lazy_; + const internal::NGramFstImpl *impl_; // Borrowed reference. + mutable NGramFstInst inst_; + + size_t i_; + uint32 flags_; +}; + +} // namespace fst +#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ diff --git a/projects/llm_framework/include/fst/extensions/ngram/nthbit.h b/projects/llm_framework/include/fst/extensions/ngram/nthbit.h new file mode 100644 index 00000000..1e6ec635 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/ngram/nthbit.h @@ -0,0 +1,49 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_ +#define FST_EXTENSIONS_NGRAM_NTHBIT_H_ + +#include +#include + +#ifdef __BMI2__ +// PDEP requires BMI2. + +// Returns the position (0-63) of the r-th 1 bit in v. +// 1 <= r <= CountOnes(v) <= 64. Therefore, v must not be 0. +inline uint32 nth_bit(uint64 v, uint32 r) { + // PDEP example from https://stackoverflow.com/a/27453505 + return __builtin_ctzll(_pdep_u64(uint64{1} << (r - 1), v)); +} + +#else // !defined(__BMI2__) + +extern const uint32 nth_bit_bit_offset[]; + +// Returns the position (0-63) of the r-th 1 bit in v. +// 1 <= r <= CountOnes(v) <= 64. Therefore, v must not be 0. +inline uint32 nth_bit(uint64 v, uint32 r) { + uint32 shift = 0; + uint32 c = __builtin_popcount(v & 0xffffffff); + uint32 mask = -(r > c); + r -= c & mask; + shift += (32 & mask); + + c = __builtin_popcount((v >> shift) & 0xffff); + mask = -(r > c); + r -= c & mask; + shift += (16 & mask); + + c = __builtin_popcount((v >> shift) & 0xff); + mask = -(r > c); + r -= c & mask; + shift += (8 & mask); + + return shift + + ((nth_bit_bit_offset[(v >> shift) & 0xff] >> ((r - 1) << 2)) & 0xf); +} + +#endif // !defined(__BMI2__) + +#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/collection.h b/projects/llm_framework/include/fst/extensions/pdt/collection.h new file mode 100644 index 00000000..ae34aba7 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/collection.h @@ -0,0 +1,107 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to store a collection of ordered (multi-)sets with elements of type T. + +#ifndef FST_EXTENSIONS_PDT_COLLECTION_H_ +#define FST_EXTENSIONS_PDT_COLLECTION_H_ + +#include +#include + +#include +#include + +namespace fst { + +// Stores a collection of non-empty, ordered (multi-)sets with elements of type +// T. A default constructor, operator==, and an STL-style hash functor must be +// defined on the elements. Provides signed integer ID (of type I) for each +// unique set. The IDs are allocated starting from 0 in order. +template +class Collection { + public: + struct Node { // Trie node. + I node_id; // Root is kNoNodeId; + T element; + + Node() : node_id(kNoNodeId), element(T()) {} + + Node(I i, const T &t) : node_id(i), element(t) {} + + bool operator==(const Node &n) const { + return n.node_id == node_id && n.element == element; + } + }; + + struct NodeHash { + size_t operator()(const Node &n) const { + static constexpr auto kPrime = 7853; + return n.node_id + hash_(n.element) * kPrime; + } + }; + + using NodeTable = CompactHashBiTable; + + class SetIterator { + public: + SetIterator(I id, Node node, NodeTable *node_table) + : id_(id), node_(node), node_table_(node_table) {} + + bool Done() const { return id_ == kNoNodeId; } + + const T &Element() const { return node_.element; } + + void Next() { + id_ = node_.node_id; + if (id_ != kNoNodeId) node_ = node_table_->FindEntry(id_); + } + + private: + I id_; // Iterator set node ID. + Node node_; // Iterator set node. + NodeTable *node_table_; + }; + + Collection() {} + + // Looks up integer ID from ordered multi-se, and if it doesn't exist and + // insert is true, then adds it. Otherwise returns -1. + I FindId(const std::vector &set, bool insert = true) { + I node_id = kNoNodeId; + for (ssize_t i = set.size() - 1; i >= 0; --i) { + Node node(node_id, set[i]); + node_id = node_table_.FindId(node, insert); + if (node_id == -1) break; + } + return node_id; + } + + // Finds ordered (multi-)set given integer ID. Returns set iterator to + // traverse result. + SetIterator FindSet(I id) { + if (id < 0 || id >= node_table_.Size()) { + return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_); + } else { + return SetIterator(id, node_table_.FindEntry(id), &node_table_); + } + } + + I Size() const { return node_table_.Size(); } + + private: + static constexpr I kNoNodeId = -1; + static const std::hash hash_; + + NodeTable node_table_; +}; + +template +constexpr I Collection::kNoNodeId; + +template +const std::hash Collection::hash_ = {}; + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COLLECTION_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/compose.h b/projects/llm_framework/include/fst/extensions/pdt/compose.h new file mode 100644 index 00000000..525d613a --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/compose.h @@ -0,0 +1,493 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Composes a PDT and an FST. + +#ifndef FST_EXTENSIONS_PDT_COMPOSE_H_ +#define FST_EXTENSIONS_PDT_COMPOSE_H_ + +#include + +#include +#include + +namespace fst { + +// Returns paren arcs for Find(kNoLabel). +constexpr uint32 kParenList = 0x00000001; + +// Returns a kNolabel loop for Find(paren). +constexpr uint32 kParenLoop = 0x00000002; + +// This class is a matcher that treats parens as multi-epsilon labels. +// It is most efficient if the parens are in a range non-overlapping with +// the non-paren labels. +template +class ParenMatcher { + public: + using FST = F; + using M = SortedMatcher; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST. + ParenMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), match_type_(match_type), flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + // This doesn't copy the FST. + ParenMatcher(const FST *fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), match_type_(match_type), flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + // This makes a copy of the FST. + ParenMatcher(const ParenMatcher &matcher, bool safe = false) + : matcher_(matcher.matcher_, safe), + match_type_(matcher.match_type_), + flags_(matcher.flags_), + open_parens_(matcher.open_parens_), + close_parens_(matcher.close_parens_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ParenMatcher *Copy(bool safe = false) const { + return new ParenMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + matcher_.SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { return done_; } + + const Arc &Value() const { return paren_loop_ ? loop_ : matcher_.Value(); } + + void Next(); + + Weight Final(StateId s) { return matcher_.Final(s); } + + ssize_t Priority(StateId s) { return matcher_.Priority(s); } + + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + + uint32 Flags() const { return matcher_.Flags(); } + + void AddOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Insert(label); + } + } + + void AddCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Insert(label); + } + } + + void RemoveOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Erase(label); + } + } + + void RemoveCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Erase(label); + } + } + + void ClearOpenParens() { open_parens_.Clear(); } + + void ClearCloseParens() { close_parens_.Clear(); } + + bool IsOpenParen(Label label) const { return open_parens_.Member(label); } + + bool IsCloseParen(Label label) const { return close_parens_.Member(label); } + + private: + // Advances matcher to next open paren, returning true if it exists. + bool NextOpenParen(); + + // Advances matcher to next close paren, returning true if it exists. + bool NextCloseParen(); + + M matcher_; + MatchType match_type_; // Type of match to perform. + uint32 flags_; + // Open paren label set. + CompactSet open_parens_; + // Close paren label set. + CompactSet close_parens_; + bool open_paren_list_; // Matching open paren list? + bool close_paren_list_; // Matching close paren list? + bool paren_loop_; // Current arc is the implicit paren loop? + mutable Arc loop_; // For non-consuming symbols. + bool done_; // Matching done? + + ParenMatcher &operator=(const ParenMatcher &) = delete; +}; + +template +inline bool ParenMatcher::Find(Label match_label) { + open_paren_list_ = false; + close_paren_list_ = false; + paren_loop_ = false; + done_ = false; + // Returns all parenthesis arcs. + if (match_label == kNoLabel && (flags_ & kParenList)) { + if (open_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(open_parens_.LowerBound()); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return true; + } + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return true; + } + } + // Returns the implicit paren loop. + if (match_label > 0 && (flags_ & kParenLoop) && + (IsOpenParen(match_label) || IsCloseParen(match_label))) { + paren_loop_ = true; + return true; + } + // Returns all other labels. + if (matcher_.Find(match_label)) return true; + done_ = true; + return false; +} + +template +inline void ParenMatcher::Next() { + if (paren_loop_) { + paren_loop_ = false; + done_ = true; + } else if (open_paren_list_) { + matcher_.Next(); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return; + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + } + done_ = !matcher_.Find(kNoLabel); + } else if (close_paren_list_) { + matcher_.Next(); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + done_ = !matcher_.Find(kNoLabel); + } else { + matcher_.Next(); + done_ = matcher_.Done(); + } +} + +// Advances matcher to next open paren, returning true if it exists. +template +inline bool ParenMatcher::NextOpenParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? matcher_.Value().ilabel + : matcher_.Value().olabel; + if (label > open_parens_.UpperBound()) return false; + if (IsOpenParen(label)) return true; + } + return false; +} + +// Advances matcher to next close paren, returning true if it exists. +template +inline bool ParenMatcher::NextCloseParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? matcher_.Value().ilabel + : matcher_.Value().olabel; + if (label > close_parens_.UpperBound()) return false; + if (IsCloseParen(label)) return true; + } + return false; +} + +template +class ParenFilter { + public: + using FST1 = typename Filter::FST1; + using FST2 = typename Filter::FST2; + using Arc = typename Filter::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + + using StackId = StateId; + using ParenStack = PdtStack; + using FilterState1 = typename Filter::FilterState; + using FilterState2 = IntegerFilterState; + using FilterState = PairFilterState; + + ParenFilter(const FST1 &fst1, const FST2 &fst2, Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr, + const std::vector> *parens = nullptr, + bool expand = false, bool keep_parens = true) + : filter_(fst1, fst2, matcher1, matcher2), + parens_(parens ? *parens : std::vector>()), + expand_(expand), + keep_parens_(keep_parens), + fs_(FilterState::NoState()), + stack_(parens_), + paren_id_(-1) { + if (parens) { + for (const auto &pair : *parens) { + parens_.push_back(pair); + GetMatcher1()->AddOpenParen(pair.first); + GetMatcher2()->AddOpenParen(pair.first); + if (!expand_) { + GetMatcher1()->AddCloseParen(pair.second); + GetMatcher2()->AddCloseParen(pair.second); + } + } + } + } + + ParenFilter(const ParenFilter &filter, bool safe = false) + : filter_(filter.filter_, safe), + parens_(filter.parens_), + expand_(filter.expand_), + keep_parens_(filter.keep_parens_), + fs_(FilterState::NoState()), + stack_(filter.parens_), + paren_id_(-1) {} + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(0)); + } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + fs_ = fs; + filter_.SetState(s1, s2, fs_.GetState1()); + if (!expand_) return; + ssize_t paren_id = stack_.Top(fs.GetState2().GetState()); + if (paren_id != paren_id_) { + if (paren_id_ != -1) { + GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); + GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); + } + paren_id_ = paren_id; + if (paren_id_ != -1) { + GetMatcher1()->AddCloseParen(parens_[paren_id_].second); + GetMatcher2()->AddCloseParen(parens_[paren_id_].second); + } + } + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + const auto fs1 = filter_.FilterArc(arc1, arc2); + const auto &fs2 = fs_.GetState2(); + if (fs1 == FilterState1::NoState()) return FilterState::NoState(); + if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses. + if (keep_parens_) { + arc1->ilabel = arc2->ilabel; + } else if (arc2->ilabel) { + arc2->olabel = arc1->ilabel; + } + return FilterParen(arc2->ilabel, fs1, fs2); + } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses. + if (keep_parens_) { + arc2->olabel = arc1->olabel; + } else { + arc1->ilabel = arc2->olabel; + } + return FilterParen(arc1->olabel, fs1, fs2); + } else { + return FilterState(fs1, fs2); + } + } + + void FilterFinal(Weight *w1, Weight *w2) const { + if (fs_.GetState2().GetState() != 0) *w1 = Weight::Zero(); + filter_.FilterFinal(w1, w2); + } + + // Returns respective matchers; ownership stays with filter. + + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + return filter_.Properties(iprops) & kILabelInvariantProperties & + kOLabelInvariantProperties; + } + + private: + const FilterState FilterParen(Label label, const FilterState1 &fs1, + const FilterState2 &fs2) const { + if (!expand_) return FilterState(fs1, fs2); + const auto stack_id = stack_.Find(fs2.GetState(), label); + if (stack_id < 0) { + return FilterState::NoState(); + } else { + return FilterState(fs1, FilterState2(stack_id)); + } + } + + Filter filter_; + std::vector> parens_; + bool expand_; // Expands to FST? + bool keep_parens_; // Retains parentheses in output? + FilterState fs_; // Current filter state. + mutable ParenStack stack_; + ssize_t paren_id_; +}; + +// Class to setup composition options for PDT composition. Default is to take +// the PDT as the first composition argument. +template +class PdtComposeFstOptions + : public ComposeFstOptions< + Arc, ParenMatcher>, + ParenFilter>>>> { + public: + using Label = typename Arc::Label; + using PdtMatcher = ParenMatcher>; + using PdtFilter = ParenFilter>; + + using ComposeFstOptions::matcher1; + using ComposeFstOptions::matcher2; + using ComposeFstOptions::filter; + + PdtComposeFstOptions(const Fst &ifst1, + const std::vector> &parens, + const Fst &ifst2, bool expand = false, + bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand, + keep_parens); + } +}; + +// Class to setup composition options for PDT with FST composition. +// Specialization is for the FST as the first composition argument. +template +class PdtComposeFstOptions + : public ComposeFstOptions< + Arc, ParenMatcher>, + ParenFilter>>>> { + public: + using Label = typename Arc::Label; + using PdtMatcher = ParenMatcher>; + using PdtFilter = ParenFilter>; + + using ComposeFstOptions::matcher1; + using ComposeFstOptions::matcher2; + using ComposeFstOptions::filter; + + PdtComposeFstOptions(const Fst &ifst1, const Fst &ifst2, + const std::vector> &parens, + bool expand = false, bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand, + keep_parens); + } +}; + +enum PdtComposeFilter { + PAREN_FILTER, // Bar-Hillel construction; keeps parentheses. + EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses. + EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses. +}; + +struct PdtComposeOptions { + bool connect; // Connect output? + PdtComposeFilter filter_type; // Pre-defined filter to use. + + explicit PdtComposeOptions(bool connect = true, + PdtComposeFilter filter_type = PAREN_FILTER) + : connect(connect), filter_type(filter_type) {} +}; + +// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and an FST +// (2nd arg) with the result also a PDT encoded as an FST (3rd arg). In the +// PDTs, some transitions are labeled with open or close parentheses. To be +// interpreted as a PDT, the parens must balance on a path (see PdtExpand()). +// The open-close parenthesis label pairs are passed using the parens argument. +template +void Compose(const Fst &ifst1, + const std::vector< + std::pair> &parens, + const Fst &ifst2, MutableFst *ofst, + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions copts(ifst1, parens, ifst2, expand, + keep_parens); + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + if (opts.connect) Connect(ofst); +} + +// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as an FST +// (2nd arg) with the result also a PDT encoded as an FST (3rd arg). In the +// PDTs, some transitions are labeled with open or close parentheses. To be +// interpreted as a PDT, the parens must balance on a path (see ExpandFst()). +// The open-close parenthesis label pairs are passed using the parens argument. +template +void Compose(const Fst &ifst1, const Fst &ifst2, + const std::vector< + std::pair> &parens, + MutableFst *ofst, + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions copts(ifst1, ifst2, parens, expand, + keep_parens); + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/expand.h b/projects/llm_framework/include/fst/extensions/pdt/expand.h new file mode 100644 index 00000000..eeee781e --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/expand.h @@ -0,0 +1,933 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expands a PDT to an FST. + +#ifndef FST_EXTENSIONS_PDT_EXPAND_H_ +#define FST_EXTENSIONS_PDT_EXPAND_H_ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +template +struct PdtExpandFstOptions : public CacheOptions { + bool keep_parentheses; + PdtStack *stack; + PdtStateTable *state_table; + + explicit PdtExpandFstOptions( + const CacheOptions &opts = CacheOptions(), bool keep_parentheses = false, + PdtStack *stack = nullptr, + PdtStateTable *state_table = + nullptr) + : CacheOptions(opts), + keep_parentheses(keep_parentheses), + stack(stack), + state_table(state_table) {} +}; + +namespace internal { + +// Implementation class for PdtExpandFst. +template +class PdtExpandFstImpl : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using StateTuple = PdtStateTuple; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + PdtExpandFstImpl(const Fst &fst, + const std::vector> &parens, + const PdtExpandFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + stack_(opts.stack ? opts.stack : new PdtStack(parens)), + state_table_(opts.state_table ? opts.state_table + : new PdtStateTable()), + own_stack_(opts.stack == 0), + own_state_table_(opts.state_table == 0), + keep_parentheses_(opts.keep_parentheses) { + SetType("expand"); + const auto props = fst.Properties(kFstProperties, false); + SetProperties(PdtExpandProperties(props), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + PdtExpandFstImpl(const PdtExpandFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + stack_(new PdtStack(*impl.stack_)), + state_table_(new PdtStateTable()), + own_stack_(true), + own_state_table_(true), + keep_parentheses_(impl.keep_parentheses_) { + SetType("expand"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~PdtExpandFstImpl() override { + if (own_stack_) delete stack_; + if (own_state_table_) delete state_table_; + } + + StateId Start() { + if (!HasStart()) { + const auto s = fst_->Start(); + if (s == kNoStateId) return kNoStateId; + StateTuple tuple(s, 0); + const auto start = state_table_->FindState(tuple); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const auto &tuple = state_table_->Tuple(s); + const auto weight = fst_->Final(tuple.state_id); + if (weight != Weight::Zero() && tuple.stack_id == 0) + SetFinal(s, weight); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) ExpandState(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void ExpandState(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + for (ArcIterator> aiter(*fst_, tuple.state_id); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel); + if (stack_id == -1) { // Non-matching close parenthesis. + continue; + } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { + // Stack push/pop. + arc.ilabel = 0; + arc.olabel = 0; + } + StateTuple ntuple(arc.nextstate, stack_id); + arc.nextstate = state_table_->FindState(ntuple); + PushArc(s, arc); + } + SetArcs(s); + } + + const PdtStack &GetStack() const { return *stack_; } + + const PdtStateTable &GetStateTable() const { + return *state_table_; + } + + private: + // Properties for an expanded PDT. + inline uint64 PdtExpandProperties(uint64 inprops) { + return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); + } + + std::unique_ptr> fst_; + PdtStack *stack_; + PdtStateTable *state_table_; + bool own_stack_; + bool own_state_table_; + bool keep_parentheses_; +}; + +} // namespace internal + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version is a delayed FST. In the PDT, some transitions are labeled with open +// or close parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed using the parens +// argument. The expansion enforces the parenthesis constraints. The PDT must be +// expandable as an FST. +// +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToFst. +template +class PdtExpandFst : public ImplToFst> { + public: + using Arc = A; + + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::PdtExpandFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + PdtExpandFst(const Fst &fst, + const std::vector> &parens) + : ImplToFst( + std::make_shared(fst, parens, PdtExpandFstOptions())) {} + + PdtExpandFst(const Fst &fst, + const std::vector> &parens, + const PdtExpandFstOptions &opts) + : ImplToFst(std::make_shared(fst, parens, opts)) {} + + // See Fst<>::Copy() for doc. + PdtExpandFst(const PdtExpandFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this ExpandFst. See Fst<>::Copy() for further doc. + PdtExpandFst *Copy(bool safe = false) const override { + return new PdtExpandFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + const PdtStack &GetStack() const { + return GetImpl()->GetStack(); + } + + const PdtStateTable &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + void operator=(const PdtExpandFst &) = delete; +}; + +// Specialization for PdtExpandFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const PdtExpandFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for PdtExpandFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const PdtExpandFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s); + } +}; + +template +inline void PdtExpandFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// PrunedExpand prunes the delayed expansion of a pushdown transducer (PDT) +// encoded as an FST into an FST. In the PDT, some transitions are labeled with +// open or close parentheses. To be interpreted as a PDT, the parens must +// balance on a path. The open-close parenthesis label pairs are passed +// using the parens argument. The expansion enforces the parenthesis +// constraints. +// +// The algorithm works by visiting the delayed ExpandFst using a shortest-stack +// first queue discipline and relies on the shortest-distance information +// computed using a reverse shortest-path call to perform the pruning. +// +// The algorithm maintains the same state ordering between the ExpandFst being +// visited (efst_) and the result of pruning written into the MutableFst (ofst_) +// to improve readability. +template +class PdtPrunedExpand { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using Stack = PdtStack; + using StateTable = PdtStateTable; + using SetIterator = typename internal::PdtBalanceData::SetIterator; + + // Constructor taking as input a PDT specified by by an input FST and a vector + // of parentheses. The keep_parentheses argument specifies whether parentheses + // are replaced by epsilons or not during the expansion. The cache options are + // passed to the underlying ExpandFst. + PdtPrunedExpand(const Fst &ifst, + const std::vector> &parens, + bool keep_parentheses = false, + const CacheOptions &opts = CacheOptions()) + : ifst_(ifst.Copy()), + keep_parentheses_(keep_parentheses), + stack_(parens), + efst_(ifst, parens, + PdtExpandFstOptions(opts, true, &stack_, &state_table_)), + queue_(state_table_, stack_, stack_length_, distance_, fdistance_), + error_(false) { + Reverse(*ifst_, parens, &rfst_); + VectorFst path; + reverse_shortest_path_.reset(new PdtShortestPath>( + rfst_, parens, + PdtShortestPathOptions>(true, false))); + reverse_shortest_path_->ShortestPath(&path); + error_ = (path.Properties(kError, true) == kError); + balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse( + rfst_.NumStates(), 10, -1)); + InitCloseParenMultimap(parens); + } + + bool Error() const { return error_; } + + // Expands and prunes the input PDT according to the provided weight + // threshold, wirting the result into an output mutable FST. + void Expand(MutableFst *ofst, const Weight &threshold); + + private: + static constexpr uint8 kEnqueued = 0x01; + static constexpr uint8 kExpanded = 0x02; + static constexpr uint8 kSourceState = 0x04; + + // Comparison functor used by the queue: + // + // 1. States corresponding to shortest stack first, and + // 2. for stacks of matching length, reverse lexicographic order is used, and + // 3. for states with the same stack, shortest-first order is used. + class StackCompare { + public: + StackCompare(const StateTable &state_table, const Stack &stack, + const std::vector &stack_length, + const std::vector &distance, + const std::vector &fdistance) + : state_table_(state_table), + stack_(stack), + stack_length_(stack_length), + distance_(distance), + fdistance_(fdistance) {} + + bool operator()(StateId s1, StateId s2) const { + auto si1 = state_table_.Tuple(s1).stack_id; + auto si2 = state_table_.Tuple(s2).stack_id; + if (stack_length_[si1] < stack_length_[si2]) return true; + if (stack_length_[si1] > stack_length_[si2]) return false; + // If stack IDs are equal, use A*. + if (si1 == si2) { + return less_(Distance(s1), Distance(s2)); + } + // If lengths are equal, uses reverse lexicographic order. + for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { + if (stack_.Top(si1) < stack_.Top(si2)) return true; + if (stack_.Top(si1) > stack_.Top(si2)) return false; + } + return false; + } + + private: + Weight Distance(StateId s) const { + return (s < distance_.size()) && (s < fdistance_.size()) + ? Times(distance_[s], fdistance_[s]) + : Weight::Zero(); + } + + const StateTable &state_table_; + const Stack &stack_; + const std::vector &stack_length_; + const std::vector &distance_; + const std::vector &fdistance_; + const NaturalLess less_; + }; + + class ShortestStackFirstQueue + : public ShortestFirstQueue { + public: + ShortestStackFirstQueue(const PdtStateTable &state_table, + const Stack &stack, + const std::vector &stack_length, + const std::vector &distance, + const std::vector &fdistance) + : ShortestFirstQueue(StackCompare( + state_table, stack, stack_length, distance, fdistance)) {} + }; + + void InitCloseParenMultimap( + const std::vector> &parens); + + Weight DistanceToDest(StateId source, StateId dest) const; + + uint8 Flags(StateId s) const; + + void SetFlags(StateId s, uint8 flags, uint8 mask); + + Weight Distance(StateId s) const; + + void SetDistance(StateId s, Weight weight); + + Weight FinalDistance(StateId s) const; + + void SetFinalDistance(StateId s, Weight weight); + + StateId SourceState(StateId s) const; + + void SetSourceState(StateId s, StateId p); + + void AddStateAndEnqueue(StateId s); + + void Relax(StateId s, const Arc &arc, Weight weight); + + bool PruneArc(StateId s, const Arc &arc); + + void ProcStart(); + + void ProcFinal(StateId s); + + bool ProcNonParen(StateId s, const Arc &arc, bool add_arc); + + bool ProcOpenParen(StateId s, const Arc &arc, StackId si, StackId nsi); + + bool ProcCloseParen(StateId s, const Arc &arc); + + void ProcDestStates(StateId s, StackId si); + + // Input PDT. + std::unique_ptr> ifst_; + // Reversed PDT. + VectorFst rfst_; + // Keep parentheses in ofst? + const bool keep_parentheses_; + // State table for efst_. + StateTable state_table_; + // Stack trie. + Stack stack_; + // Expanded PDT. + PdtExpandFst efst_; + // Length of stack for given stack ID. + std::vector stack_length_; + // Distance from initial state in efst_/ofst. + std::vector distance_; + // Distance to final states in efst_/ofst. + std::vector fdistance_; + // Queue used to visit efst_. + ShortestStackFirstQueue queue_; + // Construction time failure? + bool error_; + // Status flags for states in efst_/ofst. + std::vector flags_; + // PDT source state for each expanded state. + std::vector sources_; + // Shortest path for rfst_. + std::unique_ptr>> + reverse_shortest_path_; + std::unique_ptr> balance_data_; + // Maps open paren arcs to balancing close paren arcs. + typename PdtShortestPath>::CloseParenMultimap + close_paren_multimap_; + MutableFst *ofst_; // Output FST. + Weight limit_; // Weight limit. + + // Maps a state s in ifst (i.e., the source of a closed paranthesis matching + // the top of current_stack_id_ to final states in efst_. + std::unordered_map dest_map_; + // Stack ID of the states currently at the top of the queue, i.e., the states + // currently being popped and processed. + StackId current_stack_id_; + ssize_t current_paren_id_; // Paren ID at top of current stack. + ssize_t cached_stack_id_; + StateId cached_source_; + // The set of pairs of destination states and weights to final states for the + // source state cached_source_ and the paren ID cached_paren_id_; i.e., the + // set of source states of a closed parenthesis with paren ID cached_paren_id + // balancing an incoming open parenthesis with paren ID cached_paren_id_ in + // state cached_source_. + std::forward_list> cached_dest_list_; + NaturalLess less_; +}; + +// Initializes close paren multimap, mapping pairs (s, paren_id) to all the arcs +// out of s labeled with close parenthese for paren_id. +template +void PdtPrunedExpand::InitCloseParenMultimap( + const std::vector> &parens) { + std::unordered_map paren_map; + for (size_t i = 0; i < parens.size(); ++i) { + const auto &pair = parens[i]; + paren_map[pair.first] = i; + paren_map[pair.second] = i; + } + for (StateIterator> siter(*ifst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + for (ArcIterator> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + const auto it = paren_map.find(arc.ilabel); + if (it == paren_map.end()) continue; + if (arc.ilabel == parens[it->second].second) { // Close paren. + const internal::ParenState key(it->second, s); + close_paren_multimap_.emplace(key, arc); + } + } + } +} + +// Returns the weight of the shortest balanced path from source to dest +// in ifst_; dest must be the source state of a close paren arc. +template +typename Arc::Weight PdtPrunedExpand::DistanceToDest(StateId source, + StateId dest) const { + using SearchState = + typename PdtShortestPath>::SearchState; + const SearchState ss(source + 1, dest + 1); + const auto distance = + reverse_shortest_path_->GetShortestPathData().Distance(ss); + VLOG(2) << "D(" << source << ", " << dest << ") =" << distance; + return distance; +} + +// Returns the flags for state s in ofst_. +template +uint8 PdtPrunedExpand::Flags(StateId s) const { + return s < flags_.size() ? flags_[s] : 0; +} + +// Modifies the flags for state s in ofst_. +template +void PdtPrunedExpand::SetFlags(StateId s, uint8 flags, uint8 mask) { + while (flags_.size() <= s) flags_.push_back(0); + flags_[s] &= ~mask; + flags_[s] |= flags & mask; +} + +// Returns the shortest distance from the initial state to s in ofst_. +template +typename Arc::Weight PdtPrunedExpand::Distance(StateId s) const { + return s < distance_.size() ? distance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from the initial state to s in ofst_. +template +void PdtPrunedExpand::SetDistance(StateId s, Weight weight) { + while (distance_.size() <= s) distance_.push_back(Weight::Zero()); + distance_[s] = std::move(weight); +} + +// Returns the shortest distance from s to the final states in ofst_. +template +typename Arc::Weight PdtPrunedExpand::FinalDistance(StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from s to the final states in ofst_. +template +void PdtPrunedExpand::SetFinalDistance(StateId s, Weight weight) { + while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); + fdistance_[s] = std::move(weight); +} + +// Returns the PDT source state of state s in ofst_. +template +typename Arc::StateId PdtPrunedExpand::SourceState(StateId s) const { + return s < sources_.size() ? sources_[s] : kNoStateId; +} + +// Sets the PDT source state of state s in ofst_ to state p'in ifst_. +template +void PdtPrunedExpand::SetSourceState(StateId s, StateId p) { + while (sources_.size() <= s) sources_.push_back(kNoStateId); + sources_[s] = p; +} + +// Adds state s of efst_ to ofst_ and inserts it in the queue, modifying the +// flags for s accordingly. +template +void PdtPrunedExpand::AddStateAndEnqueue(StateId s) { + if (!(Flags(s) & (kEnqueued | kExpanded))) { + while (ofst_->NumStates() <= s) ofst_->AddState(); + queue_.Enqueue(s); + SetFlags(s, kEnqueued, kEnqueued); + } else if (Flags(s) & kEnqueued) { + queue_.Update(s); + } + // TODO(allauzen): Check everything is fine when kExpanded? +} + +// Relaxes arc out of state s in ofst_ as follows: +// +// 1. If the distance to s times the weight of arc is smaller than +// the currently stored distance for arc.nextstate, updates +// Distance(arc.nextstate) with a new estimate +// 2. If fd is less than the currently stored distance from arc.nextstate to the +// final state, updates with new estimate. +template +void PdtPrunedExpand::Relax(StateId s, const Arc &arc, Weight fd) { + const auto nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) { + SetDistance(arc.nextstate, nd); + SetSourceState(arc.nextstate, SourceState(s)); + } + if (less_(fd, FinalDistance(arc.nextstate))) { + SetFinalDistance(arc.nextstate, fd); + } + VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " + << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) + << ", nd = " << nd; +} + +// Returns whether the arc out of state s in efst needs pruned. +template +bool PdtPrunedExpand::PruneArc(StateId s, const Arc &arc) { + VLOG(2) << "Prune ?"; + auto fd = Weight::Zero(); + if ((cached_source_ != SourceState(s)) || + (cached_stack_id_ != current_stack_id_)) { + cached_source_ = SourceState(s); + cached_stack_id_ = current_stack_id_; + cached_dest_list_.clear(); + if (cached_source_ != ifst_->Start()) { + for (auto set_iter = + balance_data_->Find(current_paren_id_, cached_source_); + !set_iter.Done(); set_iter.Next()) { + auto dest = set_iter.Element(); + const auto it = dest_map_.find(dest); + cached_dest_list_.push_front(*it); + } + } else { + // TODO(allauzen): queue discipline should prevent this from ever + // happening. + // Replace by a check. + cached_dest_list_.push_front( + std::make_pair(rfst_.Start() - 1, Weight::One())); + } + } + for (auto it = cached_dest_list_.begin(); it != cached_dest_list_.end(); + ++it) { + const auto d = + DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first); + fd = Plus(fd, Times(d, it->second)); + } + Relax(s, arc, fd); + return less_(limit_, Times(Distance(s), Times(arc.weight, fd))); +} + +// Adds start state of efst_ to ofst_, enqueues it, and initializes the distance +// data structures. +template +void PdtPrunedExpand::ProcStart() { + const auto s = efst_.Start(); + AddStateAndEnqueue(s); + ofst_->SetStart(s); + SetSourceState(s, ifst_->Start()); + current_stack_id_ = 0; + current_paren_id_ = -1; + stack_length_.push_back(0); + const auto r = rfst_.Start() - 1; + cached_source_ = ifst_->Start(); + cached_stack_id_ = 0; + cached_dest_list_.push_front(std::make_pair(r, Weight::One())); + const PdtStateTuple tuple(r, 0); + SetFinalDistance(state_table_.FindState(tuple), Weight::One()); + SetDistance(s, Weight::One()); + const auto d = DistanceToDest(ifst_->Start(), r); + SetFinalDistance(s, d); + VLOG(2) << d; +} + +// Makes s final in ofst_ if shortest accepting path ending in s is below +// threshold. +template +void PdtPrunedExpand::ProcFinal(StateId s) { + const auto weight = efst_.Final(s); + if (weight == Weight::Zero()) return; + if (less_(limit_, Times(Distance(s), weight))) return; + ofst_->SetFinal(s, weight); +} + +// Returns true when an arc (or meta-arc) leaving state s in efst_ is below the +// threshold. When add_arc is true, arc is added to ofst_. +template +bool PdtPrunedExpand::ProcNonParen(StateId s, const Arc &arc, + bool add_arc) { + VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate << ", " + << arc.ilabel << ":" << arc.olabel << " / " << arc.weight + << ", add_arc = " << (add_arc ? "true" : "false"); + if (PruneArc(s, arc)) return false; + if (add_arc) ofst_->AddArc(s, arc); + AddStateAndEnqueue(arc.nextstate); + return true; +} + +// Processes an open paren arc leaving state s in ofst_. When the arc is labeled +// with an open paren, +// +// 1. Considers each (shortest) balanced path starting in s by taking the arc +// and ending by a close paren balancing the open paren of as a meta-arc, +// processing and pruning each meta-arc as a non-paren arc, inserting its +// destination to the queue; +// 2. if at least one of these meta-arcs has not been pruned, adds the +// destination of arc to ofst_ as a new source state for the stack ID nsi, and +// inserts it in the queue. +template +bool PdtPrunedExpand::ProcOpenParen(StateId s, const Arc &arc, StackId si, + StackId nsi) { + // Updates the stack length when needed. + while (stack_length_.size() <= nsi) stack_length_.push_back(-1); + if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1; + const auto ns = arc.nextstate; + VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id + << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; + bool proc_arc = false; + auto fd = Weight::Zero(); + const auto paren_id = stack_.ParenId(arc.ilabel); + std::forward_list sources; + for (auto set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); + !set_iter.Done(); set_iter.Next()) { + sources.push_front(set_iter.Element()); + } + for (const auto source : sources) { + VLOG(2) << "Close paren source: " << source; + const internal::ParenState paren_state(paren_id, source); + for (auto it = close_paren_multimap_.find(paren_state); + it != close_paren_multimap_.end() && paren_state == it->first; ++it) { + auto meta_arc = it->second; + const PdtStateTuple tuple(meta_arc.nextstate, si); + meta_arc.nextstate = state_table_.FindState(tuple); + const auto state_id = state_table_.Tuple(ns).state_id; + const auto d = DistanceToDest(state_id, source); + VLOG(2) << state_id << ", " << source; + VLOG(2) << "Meta arc weight = " << arc.weight << " Times " << d + << " Times " << meta_arc.weight; + meta_arc.weight = Times(arc.weight, Times(d, meta_arc.weight)); + proc_arc |= ProcNonParen(s, meta_arc, false); + fd = Plus( + fd, + Times(Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), + it->second.weight), + FinalDistance(meta_arc.nextstate))); + } + } + if (proc_arc) { + VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; + ofst_->AddArc( + s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + AddStateAndEnqueue(arc.nextstate); + const auto nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd); + // FinalDistance not necessary for source state since pruning decided using + // meta-arcs above. But this is a problem with A*, hence the following. + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + SetFlags(arc.nextstate, kSourceState, kSourceState); + } + return proc_arc; +} + +// Checks that shortest path through close paren arc in efst_ is below +// threshold, and if so, adds it to ofst_. +template +bool PdtPrunedExpand::ProcCloseParen(StateId s, const Arc &arc) { + const auto weight = + Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate))); + if (less_(limit_, weight)) return false; + ofst_->AddArc(s, + keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + return true; +} + +// When state s in ofst_ is a source state for stack ID si, identifies all the +// corresponding possible destination states, that is, all the states in ifst_ +// that have an outgoing close paren arc balancing the incoming open paren taken +// to get to s. For each such state t, computes the shortest distance from (t, +// si) to the final states in ofst_. Stores this information in dest_map_. +template +void PdtPrunedExpand::ProcDestStates(StateId s, StackId si) { + if (!(Flags(s) & kSourceState)) return; + if (si != current_stack_id_) { + dest_map_.clear(); + current_stack_id_ = si; + current_paren_id_ = stack_.Top(current_stack_id_); + VLOG(2) << "StackID " << si << " dequeued for first time"; + } + // TODO(allauzen): clean up source state business; rename current function to + // ProcSourceState. + SetSourceState(s, state_table_.Tuple(s).state_id); + const auto paren_id = stack_.Top(si); + for (auto set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); + !set_iter.Done(); set_iter.Next()) { + const auto dest_state = set_iter.Element(); + if (dest_map_.find(dest_state) != dest_map_.end()) continue; + auto dest_weight = Weight::Zero(); + internal::ParenState paren_state(paren_id, dest_state); + for (auto it = close_paren_multimap_.find(paren_state); + it != close_paren_multimap_.end() && paren_state == it->first; ++it) { + const auto &arc = it->second; + const PdtStateTuple tuple(arc.nextstate, + stack_.Pop(si)); + dest_weight = + Plus(dest_weight, + Times(arc.weight, FinalDistance(state_table_.FindState(tuple)))); + } + dest_map_[dest_state] = dest_weight; + VLOG(2) << "State " << dest_state << " is a dest state for stack ID " << si + << " with weight " << dest_weight; + } +} + +// Expands and prunes the input PDT, writing the result in ofst. +template +void PdtPrunedExpand::Expand(MutableFst *ofst, + const typename Arc::Weight &threshold) { + ofst_ = ofst; + if (error_) { + ofst_->SetProperties(kError, kError); + return; + } + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst_->InputSymbols()); + ofst_->SetOutputSymbols(ifst_->OutputSymbols()); + limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); + flags_.clear(); + ProcStart(); + while (!queue_.Empty()) { + const auto s = queue_.Head(); + queue_.Dequeue(); + SetFlags(s, kExpanded, kExpanded | kEnqueued); + VLOG(2) << s << " dequeued!"; + ProcFinal(s); + StackId stack_id = state_table_.Tuple(s).stack_id; + ProcDestStates(s, stack_id); + for (ArcIterator> aiter(efst_, s); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); + const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; + if (stack_id == nextstack_id) { + ProcNonParen(s, arc, true); + } else if (stack_id == stack_.Pop(nextstack_id)) { + ProcOpenParen(s, arc, stack_id, nextstack_id); + } else { + ProcCloseParen(s, arc); + } + } + VLOG(2) << "d[" << s << "] = " << Distance(s) << ", fd[" << s + << "] = " << FinalDistance(s); + } +} + +// Expand functions. + +template +struct PdtExpandOptions { + using Weight = typename Arc::Weight; + + bool connect; + bool keep_parentheses; + Weight weight_threshold; + + PdtExpandOptions(bool connect = true, bool keep_parentheses = false, + Weight weight_threshold = Weight::Zero()) + : connect(connect), + keep_parentheses(keep_parentheses), + weight_threshold(std::move(weight_threshold)) {} +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version writes the expanded PDT to a mutable FST. In the PDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// a PDT, the parens must balance on a path. The open-close parenthesis label +// pairs are passed using the parens argument. Expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. +template +void Expand( + const Fst &ifst, + const std::vector> + &parens, + MutableFst *ofst, const PdtExpandOptions &opts) { + PdtExpandFstOptions eopts; + eopts.gc_limit = 0; + if (opts.weight_threshold == Arc::Weight::Zero()) { + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = PdtExpandFst(ifst, parens, eopts); + } else { + PdtPrunedExpand pruned_expand(ifst, parens, opts.keep_parentheses); + pruned_expand.Expand(ofst, opts.weight_threshold); + } + if (opts.connect) Connect(ofst); +} + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version writes the expanded PDT result to a mutable FST. In the PDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// a PDT, the parens must balance on a path. The open-close parenthesis label +// pairs are passed using the parents argument. Expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. +template +void Expand(const Fst &ifst, + const std::vector> + &parens, MutableFst *ofst, bool connect = true, + bool keep_parentheses = false) { + const PdtExpandOptions opts(connect, keep_parentheses); + Expand(ifst, parens, ofst, opts); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_EXPAND_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/getters.h b/projects/llm_framework/include/fst/extensions/pdt/getters.h new file mode 100644 index 00000000..69dd150d --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/getters.h @@ -0,0 +1,22 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_PDT_GETTERS_H_ +#define FST_EXTENSIONS_PDT_GETTERS_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +bool GetPdtComposeFilter(const string &str, PdtComposeFilter *cf); + +bool GetPdtParserType(const string &str, PdtParserType *pt); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/info.h b/projects/llm_framework/include/fst/extensions/pdt/info.h new file mode 100644 index 00000000..3de54772 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/info.h @@ -0,0 +1,152 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Prints information about a PDT. + +#ifndef FST_EXTENSIONS_PDT_INFO_H_ +#define FST_EXTENSIONS_PDT_INFO_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +// Compute various information about PDTs. +template +class PdtInfo { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + PdtInfo(const Fst &fst, + const std::vector> &parents); + + const string &FstType() const { return fst_type_; } + + const string &ArcType() const { return Arc::Type(); } + + int64 NumStates() const { return nstates_; } + + int64 NumArcs() const { return narcs_; } + + int64 NumOpenParens() const { return nopen_parens_; } + + int64 NumCloseParens() const { return nclose_parens_; } + + int64 NumUniqueOpenParens() const { return nuniq_open_parens_; } + + int64 NumUniqueCloseParens() const { return nuniq_close_parens_; } + + int64 NumOpenParenStates() const { return nopen_paren_states_; } + + int64 NumCloseParenStates() const { return nclose_paren_states_; } + + private: + string fst_type_; + int64 nstates_; + int64 narcs_; + int64 nopen_parens_; + int64 nclose_parens_; + int64 nuniq_open_parens_; + int64 nuniq_close_parens_; + int64 nopen_paren_states_; + int64 nclose_paren_states_; +}; + +template +PdtInfo::PdtInfo( + const Fst &fst, + const std::vector> + &parens) + : fst_type_(fst.Type()), + nstates_(0), + narcs_(0), + nopen_parens_(0), + nclose_parens_(0), + nuniq_open_parens_(0), + nuniq_close_parens_(0), + nopen_paren_states_(0), + nclose_paren_states_(0) { + std::unordered_map paren_map; + std::unordered_set> +class CompactFst; + +template +class ConstFst; + +template +class EditFst; + +template +class ExpandedFst; + +template +class Fst; + +template +class MutableFst; + +template > +class VectorState; + +template > +class VectorFst; + +template +class DefaultReplaceStateTable; + +// On-the-fly operations. + +template +class ArcSortFst; + +template +class ClosureFst; + +template > +class ComposeFst; + +template +class ConcatFst; + +template +class DeterminizeFst; + +template +class DifferenceFst; + +template +class IntersectFst; + +template +class InvertFst; + +template +class ArcMapFst; + +template +class ProjectFst; + +template +class RandGenFst; + +template +class RelabelFst; + +template , + class Store = DefaultCacheStore> +class ReplaceFst; + +template +class RmEpsilonFst; + +template +class UnionFst; + +// Heap. + +template +class Heap; + +// Compactors. + +template +class AcceptorCompactor; + +template +class StringCompactor; + +template +class UnweightedAcceptorCompactor; + +template +class UnweightedCompactor; + +template +class WeightedStringCompactor; + +// Compact FSTs. + +template +using CompactStringFst = CompactFst, U>; + +template +using CompactWeightedStringFst = + CompactFst, U>; + +template +using CompactAcceptorFst = CompactFst, U>; + +template +using CompactUnweightedFst = CompactFst, U>; + +template +using CompactUnweightedAcceptorFst = + CompactFst, U>; + +// StdArc aliases for FSTs. + +using StdConstFst = ConstFst; +using StdExpandedFst = ExpandedFst; +using StdFst = Fst; +using StdMutableFst = MutableFst; +using StdVectorFst = VectorFst; + +// StdArc aliases for on-the-fly operations. + +template +using StdArcSortFst = ArcSortFst; + +using StdClosureFst = ClosureFst; + +using StdComposeFst = ComposeFst; + +using StdConcatFst = ConcatFst; + +using StdDeterminizeFst = DeterminizeFst; + +using StdDifferenceFst = DifferenceFst; + +using StdIntersectFst = IntersectFst; + +using StdInvertFst = InvertFst; + +using StdProjectFst = ProjectFst; + +using StdRelabelFst = RelabelFst; + +using StdReplaceFst = ReplaceFst; + +using StdRmEpsilonFst = RmEpsilonFst; + +using StdUnionFst = UnionFst; + +// Filter states. + +template +class IntegerFilterState; + +using CharFilterState = IntegerFilterState; + +using ShortFilterState = IntegerFilterState; // NOLINT + +using IntFilterState = IntegerFilterState; + +// Matchers and filters. + +template +class Matcher; + +template +class NullComposeFilter; + +template +class TrivialComposeFilter; + +template +class SequenceComposeFilter; + +template +class AltSequenceComposeFilter; + +template +class MatchComposeFilter; + +template +class NoMatchComposeFilter; + +} // namespace fst + +#endif // FST_FST_DECL_H_ diff --git a/projects/llm_framework/include/fst/fst.h b/projects/llm_framework/include/fst/fst.h new file mode 100644 index 00000000..20e6bb3c --- /dev/null +++ b/projects/llm_framework/include/fst/fst.h @@ -0,0 +1,1007 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST abstract base class definition, state and arc iterator interface, and +// suggested base implementation. + +#ifndef FST_FST_H_ +#define FST_FST_H_ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + +DECLARE_bool(fst_align); + +namespace fst { + +bool IsFstHeader(std::istream &, const string &); + +class FstHeader; + +template +struct StateIteratorData; + +template +struct ArcIteratorData; + +template +class MatcherBase; + +struct FstReadOptions { + // FileReadMode(s) are advisory, there are many conditions than prevent a + // file from being mapped, READ mode will be selected in these cases with + // a warning indicating why it was chosen. + enum FileReadMode { READ, MAP }; + + string source; // Where you're reading from. + const FstHeader *header; // Pointer to FST header; if non-zero, use + // this info (don't read a stream header). + const SymbolTable *isymbols; // Pointer to input symbols; if non-zero, use + // this info (read and skip stream isymbols) + const SymbolTable *osymbols; // Pointer to output symbols; if non-zero, use + // this info (read and skip stream osymbols) + FileReadMode mode; // Read or map files (advisory, if possible) + bool read_isymbols; // Read isymbols, if any (default: true). + bool read_osymbols; // Read osymbols, if any (default: true). + + explicit FstReadOptions(const string &source = "", + const FstHeader *header = nullptr, + const SymbolTable *isymbols = nullptr, + const SymbolTable *osymbols = nullptr); + + explicit FstReadOptions(const string &source, const SymbolTable *isymbols, + const SymbolTable *osymbols = nullptr); + + // Helper function to convert strings FileReadModes into their enum value. + static FileReadMode ReadMode(const string &mode); + + // Outputs a debug string for the FstReadOptions object. + string DebugString() const; +}; + +struct FstWriteOptions { + string source; // Where you're writing to. + bool write_header; // Write the header? + bool write_isymbols; // Write input symbols? + bool write_osymbols; // Write output symbols? + bool align; // Write data aligned (may fail on pipes)? + bool stream_write; // Avoid seek operations in writing. + + explicit FstWriteOptions(const string &source = "", + bool write_header = true, bool write_isymbols = true, + bool write_osymbols = true, + bool align = FLAGS_fst_align, + bool stream_write = false) + : source(source), + write_header(write_header), + write_isymbols(write_isymbols), + write_osymbols(write_osymbols), + align(align), + stream_write(stream_write) {} +}; + +// Header class. +// +// This is the recommended file header representation. + +class FstHeader { + public: + enum { + HAS_ISYMBOLS = 0x1, // Has input symbol table. + HAS_OSYMBOLS = 0x2, // Has output symbol table. + IS_ALIGNED = 0x4, // Memory-aligned (where appropriate). + } Flags; + + FstHeader() : version_(0), flags_(0), properties_(0), start_(-1), + numstates_(0), numarcs_(0) {} + + const string &FstType() const { return fsttype_; } + + const string &ArcType() const { return arctype_; } + + int32 Version() const { return version_; } + + int32 GetFlags() const { return flags_; } + + uint64 Properties() const { return properties_; } + + int64 Start() const { return start_; } + + int64 NumStates() const { return numstates_; } + + int64 NumArcs() const { return numarcs_; } + + void SetFstType(const string &type) { fsttype_ = type; } + + void SetArcType(const string &type) { arctype_ = type; } + + void SetVersion(int32 version) { version_ = version; } + + void SetFlags(int32 flags) { flags_ = flags; } + + void SetProperties(uint64 properties) { properties_ = properties; } + + void SetStart(int64 start) { start_ = start; } + + void SetNumStates(int64 numstates) { numstates_ = numstates; } + + void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; } + + bool Read(std::istream &strm, const string &source, + bool rewind = false); + + bool Write(std::ostream &strm, const string &source) const; + + // Outputs a debug string for the FstHeader object. + string DebugString() const; + + private: + string fsttype_; // E.g. "vector". + string arctype_; // E.g. "standard". + int32 version_; // Type version number. + int32 flags_; // File format bits. + uint64 properties_; // FST property bits. + int64 start_; // Start state. + int64 numstates_; // # of states. + int64 numarcs_; // # of arcs. +}; + +// Specifies matcher action. +enum MatchType { + MATCH_INPUT = 1, // Match input label. + MATCH_OUTPUT = 2, // Match output label. + MATCH_BOTH = 3, // Match input or output label. + MATCH_NONE = 4, // Match nothing. + MATCH_UNKNOWN = 5 +}; // Otherwise, match type unknown. + +constexpr int kNoLabel = -1; // Not a valid label. +constexpr int kNoStateId = -1; // Not a valid state ID. + +// A generic FST, templated on the arc definition, with common-demoninator +// methods (use StateIterator and ArcIterator to iterate over its states and +// arcs). +template +class Fst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + virtual ~Fst() {} + + // Initial state. + virtual StateId Start() const = 0; + + // State's final weight. + virtual Weight Final(StateId) const = 0; + + // State's arc count. + virtual size_t NumArcs(StateId) const = 0; + + // State's input epsilon count. + virtual size_t NumInputEpsilons(StateId) const = 0; + + // State's output epsilon count. + virtual size_t NumOutputEpsilons(StateId) const = 0; + + // Property bits. If test = false, return stored properties bits for mask + // (some possibly unknown); if test = true, return property bits for mask + // (computing o.w. unknown). + virtual uint64 Properties(uint64 mask, bool test) const = 0; + + // FST type name. + virtual const string &Type() const = 0; + + // Gets a copy of this Fst. The copying behaves as follows: + // + // (1) The copying is constant time if safe = false or if safe = true + // and is on an otherwise unaccessed FST. + // + // (2) If safe = true, the copy is thread-safe in that the original + // and copy can be safely accessed (but not necessarily mutated) by + // separate threads. For some FST types, 'Copy(true)' should only be + // called on an FST that has not otherwise been accessed. Behavior is + // otherwise undefined. + // + // (3) If a MutableFst is copied and then mutated, then the original is + // unmodified and vice versa (often by a copy-on-write on the initial + // mutation, which may not be constant time). + virtual Fst *Copy(bool safe = false) const = 0; + + // Reads an FST from an input stream; returns nullptr on error. + static Fst *Read(std::istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) { + hdr = *opts.header; + } else { + if (!hdr.Read(strm, opts.source)) return nullptr; + ropts.header = &hdr; + } + const auto &fst_type = hdr.FstType(); + const auto reader = FstRegister::GetRegister()->GetReader(fst_type); + if (!reader) { + LOG(ERROR) << "Fst::Read: Unknown FST type " << fst_type + << " (arc type = " << Arc::Type() << "): " << ropts.source; + return nullptr; + } + return reader(strm, ropts); + } + + // Reads an FST from a file; returns nullptr on error. An empty filename + // results in reading from standard input. + static Fst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "Fst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + // Writes an FST to an output stream; returns false on error. + virtual bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + LOG(ERROR) << "Fst::Write: No write stream method for " << Type() + << " FST type"; + return false; + } + + // Writes an FST to a file; returns false on error; an empty filename + // results in writing to standard output. + virtual bool Write(const string &filename) const { + LOG(ERROR) << "Fst::Write: No write filename method for " << Type() + << " FST type"; + return false; + } + + // Returns input label symbol table; return nullptr if not specified. + virtual const SymbolTable *InputSymbols() const = 0; + + // Return output label symbol table; return nullptr if not specified. + virtual const SymbolTable *OutputSymbols() const = 0; + + // For generic state iterator construction (not normally called directly by + // users). Does not copy the FST. + virtual void InitStateIterator(StateIteratorData *data) const = 0; + + // For generic arc iterator construction (not normally called directly by + // users). Does not copy the FST. + virtual void InitArcIterator(StateId s, ArcIteratorData *data) const = 0; + + // For generic matcher construction (not normally called directly by users). + // Does not copy the FST. + virtual MatcherBase *InitMatcher(MatchType match_type) const; + + protected: + bool WriteFile(const string &filename) const { + if (!filename.empty()) { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "Fst::Write: Can't open file: " << filename; + return false; + } + bool val = Write(strm, FstWriteOptions(filename)); + if (!val) LOG(ERROR) << "Fst::Write failed: " << filename; + return val; + } else { + return Write(std::cout, FstWriteOptions("standard output")); + } + } +}; + +// A useful alias when using StdArc. +using StdFst = Fst; + +// State and arc iterator definitions. +// +// State iterator interface templated on the Arc definition; used for +// StateIterator specializations returned by the InitStateIterator FST method. +template +class StateIteratorBase { + public: + using StateId = typename Arc::StateId; + + virtual ~StateIteratorBase() {} + + // End of iterator? + virtual bool Done() const = 0; + // Returns current state (when !Done()). + virtual StateId Value() const = 0; + // Advances to next state (when !Done()). + virtual void Next() = 0; + // Resets to initial condition. + virtual void Reset() = 0; +}; + +// StateIterator initialization data. + +template +struct StateIteratorData { + using StateId = typename Arc::StateId; + + // Specialized iterator if non-zero. + StateIteratorBase *base; + // Otherwise, the total number of states. + StateId nstates; + + StateIteratorData() : base(nullptr), nstates(0) {} + + StateIteratorData(const StateIteratorData &) = delete; + StateIteratorData &operator=(const StateIteratorData &) = delete; +}; + +// Generic state iterator, templated on the FST definition (a wrapper +// around a pointer to a specific one). Here is a typical use: +// +// for (StateIterator siter(fst); +// !siter.Done(); +// siter.Next()) { +// StateId s = siter.Value(); +// ... +// } +// There is no copying of the FST. +template +class StateIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + explicit StateIterator(const FST &fst) : s_(0) { + fst.InitStateIterator(&data_); + } + + ~StateIterator() { delete data_.base; } + + bool Done() const { + return data_.base ? data_.base->Done() : s_ >= data_.nstates; + } + + StateId Value() const { return data_.base ? data_.base->Value() : s_; } + + void Next() { + if (data_.base) { + data_.base->Next(); + } else { + ++s_; + } + } + + void Reset() { + if (data_.base) { + data_.base->Reset(); + } else { + s_ = 0; + } + } + + private: + StateIteratorData data_; + StateId s_; +}; + +// Flags to control the behavior on an arc iterator. +static constexpr uint32 kArcILabelValue = + 0x0001; // Value() gives valid ilabel. +static constexpr uint32 kArcOLabelValue = 0x0002; // " " " olabel. +static constexpr uint32 kArcWeightValue = 0x0004; // " " " weight. +static constexpr uint32 kArcNextStateValue = + 0x0008; // " " " nextstate. +static constexpr uint32 kArcNoCache = 0x0010; // No need to cache arcs. + +static constexpr uint32 kArcValueFlags = + kArcILabelValue | kArcOLabelValue | kArcWeightValue | kArcNextStateValue; + +static constexpr uint32 kArcFlags = kArcValueFlags | kArcNoCache; + +// Arc iterator interface, templated on the arc definition; used for arc +// iterator specializations that are returned by the InitArcIterator FST method. +template +class ArcIteratorBase { + public: + using StateId = typename Arc::StateId; + + virtual ~ArcIteratorBase() {} + + // End of iterator? + virtual bool Done() const = 0; + // Returns current arc (when !Done()). + virtual const Arc &Value() const = 0; + // Advances to next arc (when !Done()). + virtual void Next() = 0; + // Returns current position. + virtual size_t Position() const = 0; + // Returns to initial condition. + virtual void Reset() = 0; + // Advances to arbitrary arc by position. + virtual void Seek(size_t) = 0; + // Returns current behavorial flags + virtual uint32 Flags() const = 0; + // Sets behavorial flags. + virtual void SetFlags(uint32, uint32) = 0; +}; + +// ArcIterator initialization data. +template +struct ArcIteratorData { + ArcIteratorData() + : base(nullptr), arcs(nullptr), narcs(0), ref_count(nullptr) {} + + ArcIteratorData(const ArcIteratorData &) = delete; + + ArcIteratorData &operator=(const ArcIteratorData &) = delete; + + ArcIteratorBase *base; // Specialized iterator if non-zero. + const Arc *arcs; // O.w. arcs pointer + size_t narcs; // ... and arc count. + int *ref_count; // ... and reference count if non-zero. +}; + +// Generic arc iterator, templated on the FST definition (a wrapper around a +// pointer to a specific one). Here is a typical use: +// +// for (ArcIterator aiter(fst, s); +// !aiter.Done(); +// aiter.Next()) { +// StdArc &arc = aiter.Value(); +// ... +// } +// There is no copying of the FST. +template +class ArcIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + ArcIterator(const FST &fst, StateId s) : i_(0) { + fst.InitArcIterator(s, &data_); + } + + explicit ArcIterator(const ArcIteratorData &data) : data_(data), i_(0) { + if (data_.ref_count) ++(*data_.ref_count); + } + + ~ArcIterator() { + if (data_.base) { + delete data_.base; + } else if (data_.ref_count) { + --(*data_.ref_count); + } + } + + bool Done() const { + return data_.base ? data_.base->Done() : i_ >= data_.narcs; + } + + const Arc &Value() const { + return data_.base ? data_.base->Value() : data_.arcs[i_]; + } + + void Next() { + if (data_.base) { + data_.base->Next(); + } else { + ++i_; + } + } + + void Reset() { + if (data_.base) { + data_.base->Reset(); + } else { + i_ = 0; + } + } + + void Seek(size_t a) { + if (data_.base) { + data_.base->Seek(a); + } else { + i_ = a; + } + } + + size_t Position() const { return data_.base ? data_.base->Position() : i_; } + + uint32 Flags() const { + if (data_.base) { + return data_.base->Flags(); + } else { + return kArcValueFlags; + } + } + + void SetFlags(uint32 flags, uint32 mask) { + if (data_.base) data_.base->SetFlags(flags, mask); + } + + private: + ArcIteratorData data_; + size_t i_; +}; + +} // namespace fst + +// ArcIterator placement operator new and destroy function; new needs to be in +// the global namespace. + +template +void *operator new(size_t size, + fst::MemoryPool> *pool) { + return pool->Allocate(); +} + +namespace fst { + +template +void Destroy(ArcIterator *aiter, MemoryPool> *pool) { + if (aiter) { + aiter->~ArcIterator(); + pool->Free(aiter); + } +} + +// Matcher definitions. + +template +MatcherBase *Fst::InitMatcher(MatchType match_type) const { + return nullptr; // One should just use the default matcher. +} + +// FST accessors, useful in high-performance applications. + +namespace internal { + +// General case, requires non-abstract, 'final' methods. Use for inlining. + +template +inline typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) { + return fst.F::Final(s); +} + +template +inline ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumArcs(s); +} + +template +inline ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumInputEpsilons(s); +} + +template +inline ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumOutputEpsilons(s); +} + +// Fst case, abstract methods. + +template +inline typename Arc::Weight Final(const Fst &fst, + typename Arc::StateId s) { + return fst.Final(s); +} + +template +inline size_t NumArcs(const Fst &fst, typename Arc::StateId s) { + return fst.NumArcs(s); +} + +template +inline size_t NumInputEpsilons(const Fst &fst, typename Arc::StateId s) { + return fst.NumInputEpsilons(s); +} + +template +inline size_t NumOutputEpsilons(const Fst &fst, typename Arc::StateId s) { + return fst.NumOutputEpsilons(s); +} + +// FST implementation base. +// +// This is the recommended FST implementation base class. It will handle +// reference counts, property bits, type information and symbols. +// +// Users are discouraged, but not prohibited, from subclassing this outside the +// FST library. +template +class FstImpl { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + FstImpl() : properties_(0), type_("null") {} + + FstImpl(const FstImpl &impl) + : properties_(impl.properties_), + type_(impl.type_), + isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : nullptr), + osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : nullptr) {} + + FstImpl(FstImpl &&impl) noexcept; + + virtual ~FstImpl() {} + + FstImpl &operator=(const FstImpl &impl) { + properties_ = impl.properties_; + type_ = impl.type_; + isymbols_ = impl.isymbols_ ? impl.isymbols_->Copy() : nullptr; + osymbols_ = impl.osymbols_ ? impl.osymbols_->Copy() : nullptr; + return *this; + } + + FstImpl &operator=(FstImpl &&impl) noexcept; + + const string &Type() const { return type_; } + + void SetType(const string &type) { type_ = type; } + + virtual uint64 Properties() const { return properties_; } + + virtual uint64 Properties(uint64 mask) const { return properties_ & mask; } + + void SetProperties(uint64 props) { + properties_ &= kError; // kError can't be cleared. + properties_ |= props; + } + + void SetProperties(uint64 props, uint64 mask) { + properties_ &= ~mask | kError; // kError can't be cleared. + properties_ |= props & mask; + } + + // Allows (only) setting error bit on const FST implementations. + void SetProperties(uint64 props, uint64 mask) const { + if (mask != kError) { + FSTERROR() << "FstImpl::SetProperties() const: Can only set kError"; + } + properties_ |= kError; + } + + const SymbolTable *InputSymbols() const { return isymbols_.get(); } + + const SymbolTable *OutputSymbols() const { return osymbols_.get(); } + + SymbolTable *InputSymbols() { return isymbols_.get(); } + + SymbolTable *OutputSymbols() { return osymbols_.get(); } + + void SetInputSymbols(const SymbolTable *isyms) { + isymbols_.reset(isyms ? isyms->Copy() : nullptr); + } + + void SetOutputSymbols(const SymbolTable *osyms) { + osymbols_.reset(osyms ? osyms->Copy() : nullptr); + } + + // Reads header and symbols from input stream, initializes FST, and returns + // the header. If opts.header is non-null, skips reading and uses the option + // value instead. If opts.[io]symbols is non-null, reads in (if present), but + // uses the option value. + bool ReadHeader(std::istream &strm, const FstReadOptions &opts, + int min_version, FstHeader *hdr); + + // Writes header and symbols to output stream. If opts.header is false, skips + // writing header. If opts.[io]symbols is false, skips writing those symbols. + // This method is needed for implementations that implement Write methods. + void WriteHeader(std::ostream &strm, const FstWriteOptions &opts, + int version, FstHeader *hdr) const { + if (opts.write_header) { + hdr->SetFstType(type_); + hdr->SetArcType(Arc::Type()); + hdr->SetVersion(version); + hdr->SetProperties(properties_); + int32 file_flags = 0; + if (isymbols_ && opts.write_isymbols) { + file_flags |= FstHeader::HAS_ISYMBOLS; + } + if (osymbols_ && opts.write_osymbols) { + file_flags |= FstHeader::HAS_OSYMBOLS; + } + if (opts.align) file_flags |= FstHeader::IS_ALIGNED; + hdr->SetFlags(file_flags); + hdr->Write(strm, opts.source); + } + if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm); + if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm); + } + + // Writes out header and symbols to output stream. If opts.header is false, + // skips writing header. If opts.[io]symbols is false, skips writing those + // symbols. `type` is the FST type being written. This method is used in the + // cross-type serialization methods Fst::WriteFst. + static void WriteFstHeader(const Fst &fst, std::ostream &strm, + const FstWriteOptions &opts, int version, + const string &type, uint64 properties, + FstHeader *hdr) { + if (opts.write_header) { + hdr->SetFstType(type); + hdr->SetArcType(Arc::Type()); + hdr->SetVersion(version); + hdr->SetProperties(properties); + int32 file_flags = 0; + if (fst.InputSymbols() && opts.write_isymbols) { + file_flags |= FstHeader::HAS_ISYMBOLS; + } + if (fst.OutputSymbols() && opts.write_osymbols) { + file_flags |= FstHeader::HAS_OSYMBOLS; + } + if (opts.align) file_flags |= FstHeader::IS_ALIGNED; + hdr->SetFlags(file_flags); + hdr->Write(strm, opts.source); + } + if (fst.InputSymbols() && opts.write_isymbols) { + fst.InputSymbols()->Write(strm); + } + if (fst.OutputSymbols() && opts.write_osymbols) { + fst.OutputSymbols()->Write(strm); + } + } + + // In serialization routines where the header cannot be written until after + // the machine has been serialized, this routine can be called to seek to the + // beginning of the file an rewrite the header with updated fields. It + // repositions the file pointer back at the end of the file. Returns true on + // success, false on failure. + static bool UpdateFstHeader(const Fst &fst, std::ostream &strm, + const FstWriteOptions &opts, int version, + const string &type, uint64 properties, + FstHeader *hdr, size_t header_offset) { + strm.seekp(header_offset); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source; + return false; + } + WriteFstHeader(fst, strm, opts, version, type, properties, hdr); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source; + return false; + } + strm.seekp(0, std::ios_base::end); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source; + return false; + } + return true; + } + + protected: + mutable uint64 properties_; // Property bits. + + private: + string type_; // Unique name of FST class. + std::unique_ptr isymbols_; + std::unique_ptr osymbols_; +}; + +template +inline FstImpl::FstImpl(FstImpl &&) noexcept = default; + +template +inline FstImpl &FstImpl::operator=( + FstImpl &&) noexcept = default; + +template +bool FstImpl::ReadHeader(std::istream &strm, const FstReadOptions &opts, + int min_version, FstHeader *hdr) { + if (opts.header) { + *hdr = *opts.header; + } else if (!hdr->Read(strm, opts.source)) { + return false; + } + if (FLAGS_v >= 2) { + LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source + << ", fst_type: " << hdr->FstType() + << ", arc_type: " << Arc::Type() + << ", version: " << hdr->Version() + << ", flags: " << hdr->GetFlags(); + } + if (hdr->FstType() != type_) { + LOG(ERROR) << "FstImpl::ReadHeader: FST not of type " << type_ + << ": " << opts.source; + return false; + } + if (hdr->ArcType() != Arc::Type()) { + LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type " << Arc::Type() + << ": " << opts.source; + return false; + } + if (hdr->Version() < min_version) { + LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_ + << " FST version: " << opts.source; + return false; + } + properties_ = hdr->Properties(); + if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) { + isymbols_.reset(SymbolTable::Read(strm, opts.source)); + } + // Deletes input symbol table. + if (!opts.read_isymbols) SetInputSymbols(nullptr); + if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) { + osymbols_.reset(SymbolTable::Read(strm, opts.source)); + } + // Deletes output symbol table. + if (!opts.read_osymbols) SetOutputSymbols(nullptr); + if (opts.isymbols) { + isymbols_.reset(opts.isymbols->Copy()); + } + if (opts.osymbols) { + osymbols_.reset(opts.osymbols->Copy()); + } + return true; +} + +} // namespace internal + +template +uint64 TestProperties(const Fst &fst, uint64 mask, uint64 *known); + +// This is a helper class template useful for attaching an FST interface to +// its implementation, handling reference counting. +template > +class ImplToFst : public FST { + public: + using Arc = typename Impl::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + StateId Start() const override { return impl_->Start(); } + + Weight Final(StateId s) const override { return impl_->Final(s); } + + size_t NumArcs(StateId s) const override { return impl_->NumArcs(s); } + + size_t NumInputEpsilons(StateId s) const override { + return impl_->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) const override { + return impl_->NumOutputEpsilons(s); + } + + uint64 Properties(uint64 mask, bool test) const override { + if (test) { + uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops); + impl_->SetProperties(testprops, knownprops); + return testprops & mask; + } else { + return impl_->Properties(mask); + } + } + + const string &Type() const override { return impl_->Type(); } + + const SymbolTable *InputSymbols() const override { + return impl_->InputSymbols(); + } + + const SymbolTable *OutputSymbols() const override { + return impl_->OutputSymbols(); + } + + protected: + explicit ImplToFst(std::shared_ptr impl) : impl_(std::move(impl)) {} + + // This constructor presumes there is a copy constructor for the + // implementation. + ImplToFst(const ImplToFst &fst, bool safe) { + if (safe) { + impl_ = std::make_shared(*(fst.impl_)); + } else { + impl_ = fst.impl_; + } + } + + ImplToFst() = delete; + + ImplToFst(const ImplToFst &fst) : impl_(fst.impl_) {} + + ImplToFst(ImplToFst &&fst) noexcept + : impl_(std::move(fst.impl_)) { + fst.impl_ = std::make_shared(); + } + + ImplToFst &operator=(const ImplToFst &fst) { + impl_ = fst.impl_; + return *this; + } + + ImplToFst &operator=(ImplToFst &&fst) noexcept { + if (this != &fst) { + impl_ = std::move(fst.impl_); + fst.impl_ = std::make_shared(); + } + return *this; + } + + // Returns raw pointers to the shared object. + const Impl *GetImpl() const { return impl_.get(); } + + Impl *GetMutableImpl() const { return impl_.get(); } + + // Returns a ref-counted smart poiner to the implementation. + std::shared_ptr GetSharedImpl() const { return impl_; } + + bool Unique() const { return impl_.unique(); } + + void SetImpl(std::shared_ptr impl) { impl_ = std::move(impl); } + + private: + template + friend void Cast(const IFST &ifst, OFST *ofst); + + std::shared_ptr impl_; +}; + +// Converts FSTs by casting their implementations, where this makes sense +// (which excludes implementations with weight-dependent virtual methods). +// Must be a friend of the FST classes involved (currently the concrete FSTs: +// ConstFst, CompactFst, and VectorFst). This can only be safely used for arc +// types that have identical storage characteristics. As with an FST +// copy constructor and Copy() method, this is a constant time operation +// (but subject to copy-on-write if it is a MutableFst and modified). +template +void Cast(const IFST &ifst, OFST *ofst) { + using OImpl = typename OFST::Impl; + ofst->impl_ = std::shared_ptr(ifst.impl_, + reinterpret_cast(ifst.impl_.get())); +} + +// FST serialization. + +template +string FstToString(const Fst &fst, + const FstWriteOptions &options = + FstWriteOptions("FstToString")) { + std::ostringstream ostrm; + fst.Write(ostrm, options); + return ostrm.str(); +} + +template +void FstToString(const Fst &fst, string *result) { + *result = FstToString(fst); +} + +template +void FstToString(const Fst &fst, string *result, + const FstWriteOptions &options) { + *result = FstToString(fst, options); +} + +template +Fst *StringToFst(const string &s) { + std::istringstream istrm(s); + return Fst::Read(istrm, FstReadOptions("StringToFst")); +} + +} // namespace fst + +#endif // FST_FST_H_ diff --git a/projects/llm_framework/include/fst/fstlib.h b/projects/llm_framework/include/fst/fstlib.h new file mode 100644 index 00000000..e8b1c3a1 --- /dev/null +++ b/projects/llm_framework/include/fst/fstlib.h @@ -0,0 +1,130 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// This is a library for constructing, combining, optimizing, and searching +// "weighted finite-state transducers" (FSTs). Weighted finite-state transducers +// are automata where each transition has an input label, an output label, and a +// weight. The more familiar finite-state acceptor is represented as a +// transducer with each transition's input and output the same. Finite-state +// acceptors are used to represent sets of strings (specifically, "regular" or +// "rational sets"); finite-state transducers are used to represent binary +// relations between pairs of strings (specifically, "rational transductions"). +// The weights can be used to represent the cost of taking a particular +// transition. +// +// In this library, transducers are templated on the Arc (transition) +// definition, which allows changing the label, weight, and state ID sets. +// Labels and state IDs are restricted to signed integral types but the weight +// can be an arbitrary type whose members satisfy certain algebraic ("semiring") +// properties. +// +// This convenience file includes all other FST header files. + +#ifndef FST_FSTLIB_H_ +#define FST_FSTLIB_H_ + + +// Abstract FST classes. +#include +#include +#include + +// Concrete FST classes. +#include +#include +#include +#include + +// FST algorithms and delayed FST classes. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Weights. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Auxiliary classes for composition. +#include +#include +#include +#include +#include +#include + +// Data structures. +#include +#include +#include +#include + +// Miscellaneous. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#endif // FST_FSTLIB_H_ diff --git a/projects/llm_framework/include/fst/generic-register.h b/projects/llm_framework/include/fst/generic-register.h new file mode 100644 index 00000000..ea6b8fe1 --- /dev/null +++ b/projects/llm_framework/include/fst/generic-register.h @@ -0,0 +1,126 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_GENERIC_REGISTER_H_ +#define FST_GENERIC_REGISTER_H_ + +#include +#ifndef FST_NO_DYNAMIC_LINKING +#include +#endif +#include +#include + +#include +#include + +// Generic class representing a globally-stored correspondence between +// objects of KeyType and EntryType. +// +// KeyType must: +// +// * be such as can be stored as a key in a std::map<>. +// * be concatenable with a const char* with the + operator +// (or you must subclass and redefine LoadEntryFromSharedObject) +// +// EntryType must be default constructible. +// +// The third template parameter should be the type of a subclass of this class +// (think CRTP). This is to allow GetRegister() to instantiate and return an +// object of the appropriate type. + +namespace fst { + +template +class GenericRegister { + public: + using Key = KeyType; + using Entry = EntryType; + + static RegisterType *GetRegister() { + static auto reg = new RegisterType; + return reg; + } + + void SetEntry(const KeyType &key, const EntryType &entry) { + MutexLock l(®ister_lock_); + register_table_.insert(std::make_pair(key, entry)); + } + + EntryType GetEntry(const KeyType &key) const { + const auto *entry = LookupEntry(key); + if (entry) { + return *entry; + } else { + return LoadEntryFromSharedObject(key); + } + } + + virtual ~GenericRegister() {} + + protected: + // Override this if you want to be able to load missing definitions from + // shared object files. + virtual EntryType LoadEntryFromSharedObject(const KeyType &key) const { +#ifdef FST_NO_DYNAMIC_LINKING + return EntryType(); +#else + const auto so_filename = ConvertKeyToSoFilename(key); + void *handle = dlopen(so_filename.c_str(), RTLD_LAZY); + if (handle == nullptr) { + LOG(ERROR) << "GenericRegister::GetEntry: " << dlerror(); + return EntryType(); + } +#ifdef RUN_MODULE_INITIALIZERS + RUN_MODULE_INITIALIZERS(); +#endif + // We assume that the DSO constructs a static object in its global scope + // that does the registration. Thus we need only load it, not call any + // methods. + const auto *entry = this->LookupEntry(key); + if (entry == nullptr) { + LOG(ERROR) << "GenericRegister::GetEntry: " + << "lookup failed in shared object: " << so_filename; + return EntryType(); + } + return *entry; +#endif // FST_NO_DYNAMIC_LINKING + } + + // Override this to define how to turn a key into an SO filename. + virtual string ConvertKeyToSoFilename(const KeyType &key) const = 0; + + virtual const EntryType *LookupEntry(const KeyType &key) const { + MutexLock l(®ister_lock_); + const auto it = register_table_.find(key); + if (it != register_table_.end()) { + return &it->second; + } else { + return nullptr; + } + } + + private: + mutable Mutex register_lock_; + std::map register_table_; +}; + +// Generic register-er class capable of creating new register entries in the +// given RegisterType template parameter. This type must define types Key and +// Entry, and have appropriate static GetRegister() and instance SetEntry() +// functions. An easy way to accomplish this is to have RegisterType be the +// type of a subclass of GenericRegister. +template +class GenericRegisterer { + public: + using Key = typename RegisterType::Key; + using Entry = typename RegisterType::Entry; + + GenericRegisterer(Key key, Entry entry) { + RegisterType::GetRegister()->SetEntry(key, entry); + } +}; + +} // namespace fst + +#endif // FST_GENERIC_REGISTER_H_ diff --git a/projects/llm_framework/include/fst/heap.h b/projects/llm_framework/include/fst/heap.h new file mode 100644 index 00000000..041a4bb9 --- /dev/null +++ b/projects/llm_framework/include/fst/heap.h @@ -0,0 +1,168 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Implementation of a heap as in STL, but allows tracking positions in heap +// using a key. The key can be used to do an in-place update of values in the +// heap. + +#ifndef FST_HEAP_H_ +#define FST_HEAP_H_ + +#include +#include + +#include +namespace fst { + +// A templated heap implementation that supports in-place update of values. +// +// The templated heap implementation is a little different from the STL +// priority_queue and the *_heap operations in STL. This heap supports +// indexing of values in the heap via an associated key. +// +// Each value is internally associated with a key which is returned to the +// calling functions on heap insert. This key can be used to later update +// the specific value in the heap. +// +// T: the element type of the hash. It can be POD, Data or a pointer to Data. +// Compare: comparison functor for determining min-heapness. +template +class Heap { + public: + using Value = T; + + static constexpr int kNoKey = -1; + + // Initializes with a specific comparator. + explicit Heap(Compare comp = Compare()) : comp_(comp), size_(0) {} + + // Inserts a value into the heap. + int Insert(const Value &value) { + if (size_ < values_.size()) { + values_[size_] = value; + pos_[key_[size_]] = size_; + } else { + values_.push_back(value); + pos_.push_back(size_); + key_.push_back(size_); + } + ++size_; + return Insert(value, size_ - 1); + } + + // Updates a value at position given by the key. The pos_ array is first + // indexed by the key. The position gives the position in the heap array. + // Once we have the position we can then use the standard heap operations + // to calculate the parent and child positions. + void Update(int key, const Value &value) { + const auto i = pos_[key]; + const bool is_better = comp_(value, values_[Parent(i)]); + values_[i] = value; + if (is_better) { + Insert(value, i); + } else { + Heapify(i); + } + } + + // Returns the least value. + Value Pop() { + Value top = values_.front(); + Swap(0, size_-1); + size_--; + Heapify(0); + return top; + } + + // Returns the least value w.r.t. the comparison function from the + // heap. + const Value &Top() const { return values_.front(); } + + // Returns the element for the given key. + const Value &Get(int key) const { return values_[pos_[key]]; } + + // Checks if the heap is empty. + bool Empty() const { return size_ == 0; } + + void Clear() { size_ = 0; } + + int Size() const { return size_; } + + void Reserve(int size) { + values_.reserve(size); + pos_.reserve(size); + key_.reserve(size); + } + + const Compare &GetCompare() const { return comp_; } + + private: + // The following private routines are used in a supportive role + // for managing the heap and keeping the heap properties. + + // Computes left child of parent. + static int Left(int i) { + return 2 * (i + 1) - 1; // 0 -> 1, 1 -> 3 + } + + // Computes right child of parent. + static int Right(int i) { + return 2 * (i + 1); // 0 -> 2, 1 -> 4 + } + + // Given a child computes parent. + static int Parent(int i) { + return (i - 1) / 2; // 0 -> 0, 1 -> 0, 2 -> 0, 3 -> 1, 4 -> 1, ... + } + + // Swaps a child and parent. Use to move element up/down tree. Note the use of + // a little trick here. When we swap we need to swap: + // + // - the value + // - the associated keys + // - the position of the value in the heap + void Swap(int j, int k) { + const auto tkey = key_[j]; + pos_[key_[j] = key_[k]] = j; + pos_[key_[k] = tkey] = k; + using std::swap; + swap(values_[j], values_[k]); + } + + // Heapifies the subtree rooted at index i. + void Heapify(int i) { + const auto l = Left(i); + const auto r = Right(i); + auto largest = (l < size_ && comp_(values_[l], values_[i])) ? l : i; + if (r < size_ && comp_(values_[r], values_[largest])) largest = r; + if (largest != i) { + Swap(i, largest); + Heapify(largest); + } + } + + // Inserts (updates) element at subtree rooted at index i. + int Insert(const Value &value, int i) { + int p; + while (i > 0 && !comp_(values_[p = Parent(i)], value)) { + Swap(i, p); + i = p; + } + return key_[i]; + } + + private: + const Compare comp_; + + std::vector pos_; + std::vector key_; + std::vector values_; + int size_; +}; + +template +constexpr int Heap::kNoKey; + +} // namespace fst + +#endif // FST_HEAP_H_ diff --git a/projects/llm_framework/include/fst/icu.h b/projects/llm_framework/include/fst/icu.h new file mode 100644 index 00000000..5459da24 --- /dev/null +++ b/projects/llm_framework/include/fst/icu.h @@ -0,0 +1,129 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// This library implements an unrestricted Thompson/Pike UTF-8 parser and +// serializer. UTF-8 is a restricted subset of this byte stream encoding. For +// a description of the encoding details, see: +// +// http://en.wikipedia.org/wiki/UTF-8 + +#ifndef FST_ICU_H_ +#define FST_ICU_H_ + +#include +#include + +#include + +namespace fst { + +// Trivial function to copy bytestrings into vectors of labels, truncating +// if necessary. It is possible to use this sensibly with as little as 8 bits +// of Label precision. This returns `true` deterministically for compatibility. +template +bool ByteStringToLabels(const string &str, std::vector { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ComposeFst::CreateBase; + using ComposeFst::CreateBase1; + using ComposeFst::Properties; + + IntersectFst(const Fst &fst1, const Fst &fst2, + const CacheOptions &opts = CacheOptions()) + : ComposeFst(CreateBase(fst1, fst2, opts)) { + const bool acceptors = + fst1.Properties(kAcceptor, true) && fst2.Properties(kAcceptor, true); + if (!acceptors) { + FSTERROR() << "IntersectFst: Input FSTs are not acceptors"; + GetMutableImpl()->SetProperties(kError); + } + } + + template + IntersectFst(const Fst &fst1, const Fst &fst2, + const IntersectFstOptions &opts) + : ComposeFst(CreateBase1(fst1, fst2, opts)) { + const bool acceptors = + fst1.Properties(kAcceptor, true) && fst2.Properties(kAcceptor, true); + if (!acceptors) { + FSTERROR() << "IntersectFst: input FSTs are not acceptors"; + GetMutableImpl()->SetProperties(kError); + } + } + + // See Fst<>::Copy() for doc. + IntersectFst(const IntersectFst &fst, bool safe = false) + : ComposeFst(fst, safe) {} + + // Get a copy of this IntersectFst. See Fst<>::Copy() for further doc. + IntersectFst *Copy(bool safe = false) const override { + return new IntersectFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for IntersectFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const IntersectFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for IntersectFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const IntersectFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdIntersectFst = IntersectFst; + +// Computes the intersection (Hadamard product) of two FSAs. This version +// writes the intersection to an output MurableFst. Only strings that are in +// both automata are retained in the result. +// +// The two arguments must be acceptors. One of the arguments must be +// label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template +void Intersect(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const IntersectOptions &opts = IntersectOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, nopts); + break; + } + case SEQUENCE_FILTER: { + IntersectFstOptions iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case ALT_SEQUENCE_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case MATCH_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case NO_MATCH_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case NULL_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case TRIVIAL_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_INTERSECT_H_ diff --git a/projects/llm_framework/include/fst/interval-set.h b/projects/llm_framework/include/fst/interval-set.h new file mode 100644 index 00000000..0942ea00 --- /dev/null +++ b/projects/llm_framework/include/fst/interval-set.h @@ -0,0 +1,398 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to represent and operate on sets of intervals. + +#ifndef FST_INTERVAL_SET_H_ +#define FST_INTERVAL_SET_H_ + +#include +#include +#include + + +#include + + +namespace fst { + +// Half-open integral interval [a, b) of signed integers of type T. +template +struct IntInterval { + T begin; + T end; + + IntInterval() : begin(-1), end(-1) {} + + IntInterval(T begin, T end) : begin(begin), end(end) {} + + bool operator<(const IntInterval &i) const { + return begin < i.begin || (begin == i.begin && end > i.end); + } + + bool operator==(const IntInterval &i) const { + return begin == i.begin && end == i.end; + } + + bool operator!=(const IntInterval &i) const { + return begin != i.begin || end != i.end; + } + + std::istream &Read(std::istream &strm) { + T n; + ReadType(strm, &n); + begin = n; + ReadType(strm, &n); + end = n; + return strm; + } + + std::ostream &Write(std::ostream &strm) const { + T n = begin; + WriteType(strm, n); + n = end; + WriteType(strm, n); + return strm; + } +}; + +// Stores IntIntervals in a vector. In addition, keeps the count of points in +// all intervals. +template +class VectorIntervalStore { + public: + using Interval = IntInterval; + using Iterator = typename std::vector::const_iterator; + + VectorIntervalStore() : count_(-1) {} + + std::vector *MutableIntervals() { return &intervals_; } + + const Interval *Intervals() const { return intervals_.data(); } + + T Size() const { return intervals_.size(); } + + T Count() const { return count_; } + + void SetCount(T count) { count_ = count; } + + void Clear() { + intervals_.clear(); + count_ = 0; + } + + Iterator begin() const { return intervals_.begin(); } + + Iterator end() const { return intervals_.end(); } + + std::istream &Read(std::istream &strm) { + ReadType(strm, &intervals_); + return ReadType(strm, &count_); + } + + std::ostream &Write(std::ostream &strm) const { + WriteType(strm, intervals_); + return WriteType(strm, count_); + } + + private: + std::vector intervals_; + T count_; +}; + +// Stores and operates on a set of half-open integral intervals [a, b) +// of signed integers of type T. +template > +class IntervalSet { + public: + using Interval = IntInterval; + + template + explicit IntervalSet(A... args) : intervals_(args...) {} + + // Returns the interval set as a vector. + std::vector *MutableIntervals() { + return intervals_.MutableIntervals(); + } + + // Returns a pointer to an array of Size() elements. + const Interval *Intervals() const { return intervals_.Intervals(); } + + bool Empty() const { return Size() == 0; } + + T Size() const { return intervals_.Size(); } + + // Number of points in the intervals (undefined if not normalized). + T Count() const { return intervals_.Count(); } + + void Clear() { intervals_.Clear(); } + + // Adds an interval set to the set. The result may not be normalized. + void Union(const IntervalSet &iset) { + intervals_.MutableIntervals()->insert(intervals_.MutableIntervals()->end(), + iset.intervals_.begin(), + iset.intervals_.end()); + } + + // Requires intervals be normalized. + bool Member(T value) const { + const Interval interval(value, value); + auto lb = std::lower_bound(intervals_.begin(), intervals_.end(), interval); + if (lb == intervals_.begin()) return false; + return (--lb)->end > value; + } + + // Requires intervals be normalized. + bool operator==(const IntervalSet &iset) const { + return Size() == iset.Size() && + std::equal(intervals_.begin(), intervals_.end(), + iset.intervals_.begin()); + } + + // Requires intervals be normalized. + bool operator!=(const IntervalSet &iset) const { + return Size() != iset.Size() || + !std::equal(intervals_.begin(), intervals_.end(), + iset.intervals_.begin()); + } + + bool Singleton() const { + return Size() == 1 && + intervals_.begin()->begin + 1 == intervals_.begin()->end; + } + + // Sorts, collapses overlapping and adjacent interals, and sets count. + void Normalize(); + + // Intersects an interval set with the set. Requires intervals be normalized. + // The result is normalized. + void Intersect(const IntervalSet &iset, + IntervalSet *oset) const; + + // Complements the set w.r.t [0, maxval). Requires intervals be normalized. + // The result is normalized. + void Complement(T maxval, IntervalSet *oset) const; + + // Subtract an interval set from the set. Requires intervals be normalized. + // The result is normalized. + void Difference(const IntervalSet &iset, + IntervalSet *oset) const; + + // Determines if an interval set overlaps with the set. Requires intervals be + // normalized. + bool Overlaps(const IntervalSet &iset) const; + + // Determines if an interval set overlaps with the set but neither is + // contained in the other. Requires intervals be normalized. + bool StrictlyOverlaps(const IntervalSet &iset) const; + + // Determines if an interval set is contained within the set. Requires + // intervals be normalized. + bool Contains(const IntervalSet &iset) const; + + std::istream &Read(std::istream &strm) { return intervals_.Read(strm); } + + std::ostream &Write(std::ostream &strm) const { + return intervals_.Write(strm); + } + + typename Store::Iterator begin() const { return intervals_.begin(); } + + typename Store::Iterator end() const { return intervals_.end(); } + + private: + Store intervals_; +}; + +// Sorts, collapses overlapping and adjacent intervals, and sets count. +template +void IntervalSet::Normalize() { + auto &intervals = *intervals_.MutableIntervals(); + std::sort(intervals.begin(), intervals.end()); + T count = 0; + T size = 0; + for (T i = 0; i < intervals.size(); ++i) { + auto &inti = intervals[i]; + if (inti.begin == inti.end) continue; + for (T j = i + 1; j < intervals.size(); ++j) { + auto &intj = intervals[j]; + if (intj.begin > inti.end) break; + if (intj.end > inti.end) inti.end = intj.end; + ++i; + } + count += inti.end - inti.begin; + intervals[size++] = inti; + } + intervals.resize(size); + intervals_.SetCount(count); +} + +// Intersects an interval set with the set. Requires intervals be normalized. +// The result is normalized. +template +void IntervalSet::Intersect(const IntervalSet &iset, + IntervalSet *oset) const { + auto *ointervals = oset->MutableIntervals(); + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + ointervals->clear(); + T count = 0; + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { + ++it1; + } else if (it2->end <= it1->begin) { + ++it2; + } else { + ointervals->emplace_back(std::max(it1->begin, it2->begin), + std::min(it1->end, it2->end)); + count += ointervals->back().end - ointervals->back().begin; + if (it1->end < it2->end) { + ++it1; + } else { + ++it2; + } + } + } + oset->intervals_.SetCount(count); +} + +// Complements the set w.r.t [0, maxval). Requires intervals be normalized. +// The result is normalized. +template +void IntervalSet::Complement(T maxval, + IntervalSet *oset) const { + auto *ointervals = oset->MutableIntervals(); + ointervals->clear(); + T count = 0; + Interval interval; + interval.begin = 0; + for (auto it = intervals_.begin(); it != intervals_.end(); ++it) { + interval.end = std::min(it->begin, maxval); + if ((interval.begin) < (interval.end)) { + ointervals->push_back(interval); + count += interval.end - interval.begin; + } + interval.begin = it->end; + } + interval.end = maxval; + if ((interval.begin) < (interval.end)) { + ointervals->push_back(interval); + count += interval.end - interval.begin; + } + oset->intervals_.SetCount(count); +} + +// Subtract an interval set from the set. Requires intervals be normalized. +// The result is normalized. +template +void IntervalSet::Difference(const IntervalSet &iset, + IntervalSet *oset) const { + if (Empty()) { + oset->MutableIntervals()->clear(); + oset->intervals_.SetCount(0); + } else { + IntervalSet cset; + iset.Complement(intervals_.Intervals()[intervals_.Size() - 1].end, &cset); + Intersect(cset, oset); + } +} + +// Determines if an interval set overlaps with the set. Requires intervals be +// normalized. +template +bool IntervalSet::Overlaps(const IntervalSet &iset) const { + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { + ++it1; + } else if (it2->end <= it1->begin) { + ++it2; + } else { + return true; + } + } + return false; +} + +// Determines if an interval set overlaps with the set but neither is contained +// in the other. Requires intervals be normalized. +template +bool IntervalSet::StrictlyOverlaps( + const IntervalSet &iset) const { + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + bool only1 = false; // Point in intervals_ but not intervals. + bool only2 = false; // Point in intervals but not intervals_. + bool overlap = false; // Point in both intervals_ and intervals. + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { // no overlap - it1 first + only1 = true; + ++it1; + } else if (it2->end <= it1->begin) { // no overlap - it2 first + only2 = true; + ++it2; + } else if (it2->begin == it1->begin && it2->end == it1->end) { // equals + overlap = true; + ++it1; + ++it2; + } else if (it2->begin <= it1->begin && it2->end >= it1->end) { // 1 c 2 + only2 = true; + overlap = true; + ++it1; + } else if (it1->begin <= it2->begin && it1->end >= it2->end) { // 2 c 1 + only1 = true; + overlap = true; + ++it2; + } else { // Strict overlap. + only1 = true; + only2 = true; + overlap = true; + } + if (only1 == true && only2 == true && overlap == true) return true; + } + if (it1 != intervals_.end()) only1 = true; + if (it2 != iset.intervals_.end()) only2 = true; + return only1 == true && only2 == true && overlap == true; +} + +// Determines if an interval set is contained within the set. Requires intervals +// be normalized. +template +bool IntervalSet::Contains(const IntervalSet &iset) const { + if (iset.Count() > Count()) return false; + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if ((it1->end) <= (it2->begin)) { // No overlap; it1 first. + ++it1; + } else if ((it2->begin) < (it1->begin) || + (it2->end) > (it1->end)) { // No C. + return false; + } else if (it2->end == it1->end) { + ++it1; + ++it2; + } else { + ++it2; + } + } + return it2 == iset.intervals_.end(); +} + +template +std::ostream &operator<<(std::ostream &strm, const IntervalSet &s) { + strm << "{"; + for (T i = 0; i < s.Size(); ++i) { + if (i > 0) { + strm << ","; + } + const auto &interval = s.Intervals()[i]; + strm << "[" << interval.begin << "," << interval.end << ")"; + } + strm << "}"; + return strm; +} + +} // namespace fst + +#endif // FST_INTERVAL_SET_H_ diff --git a/projects/llm_framework/include/fst/invert.h b/projects/llm_framework/include/fst/invert.h new file mode 100644 index 00000000..bd243c62 --- /dev/null +++ b/projects/llm_framework/include/fst/invert.h @@ -0,0 +1,139 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to invert an FST. + +#ifndef FST_INVERT_H_ +#define FST_INVERT_H_ + +#include +#include + + +namespace fst { + +// Mapper to implement inversion of an arc. +template +struct InvertMapper { + using FromArc = A; + using ToArc = A; + + InvertMapper() {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.olabel, arc.ilabel, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { + return MAP_NO_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return InvertProperties(props); + } +}; + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(1) +// +// where V is the number of states and E is the number of arcs. +template +inline void Invert(const Fst &ifst, MutableFst *ofst) { + std::unique_ptr input( + ifst.InputSymbols() ? ifst.InputSymbols()->Copy() : nullptr); + std::unique_ptr output( + ifst.OutputSymbols() ? ifst.OutputSymbols()->Copy() : nullptr); + ArcMap(ifst, ofst, InvertMapper()); + ofst->SetInputSymbols(output.get()); + ofst->SetOutputSymbols(input.get()); +} + +// Destructive variant of the above. +template +inline void Invert(MutableFst *fst) { + std::unique_ptr input( + fst->InputSymbols() ? fst->InputSymbols()->Copy() : nullptr); + std::unique_ptr output( + fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr); + ArcMap(fst, InvertMapper()); + fst->SetInputSymbols(output.get()); + fst->SetOutputSymbols(input.get()); +} + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. This version is a delayed FST. +// +// Complexity: +// +// Time: O(v + e) +// Space: O(1) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time and to visit an input state or arc is assumed and exclusive of +// caching. +template +class InvertFst : public ArcMapFst> { + public: + using Arc = A; + + using Mapper = InvertMapper; + using Impl = internal::ArcMapFstImpl>; + + explicit InvertFst(const Fst &fst) + : ArcMapFst(fst, Mapper()) { + GetMutableImpl()->SetOutputSymbols(fst.InputSymbols()); + GetMutableImpl()->SetInputSymbols(fst.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + InvertFst(const InvertFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Get a copy of this InvertFst. See Fst<>::Copy() for further doc. + InvertFst *Copy(bool safe = false) const override { + return new InvertFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for InvertFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const InvertFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for InvertFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const InvertFst &fst, StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdInvertFst = InvertFst; + +} // namespace fst + +#endif // FST_INVERT_H_ diff --git a/projects/llm_framework/include/fst/isomorphic.h b/projects/llm_framework/include/fst/isomorphic.h new file mode 100644 index 00000000..b100b0a6 --- /dev/null +++ b/projects/llm_framework/include/fst/isomorphic.h @@ -0,0 +1,183 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to test two FSTs are isomorphic, i.e., they are equal up to a state +// and arc re-ordering. FSTs should be deterministic when viewed as +// unweighted automata. + +#ifndef FST_ISOMORPHIC_H_ +#define FST_ISOMORPHIC_H_ + +#include +#include +#include +#include + +#include + +#include + + +namespace fst { +namespace internal { + +// Orders weights for equality checking. +template ::value>::type * = nullptr> +bool WeightCompare(const Weight &w1, const Weight &w2, float delta, + bool *error) { + return NaturalLess()(w1, w2); +} + +template ::value>::type * = nullptr> +bool WeightCompare(const Weight &w1, const Weight &w2, float delta, + bool *error) { + // No natural order; use hash. + const auto q1 = w1.Quantize(delta); + const auto q2 = w2.Quantize(delta); + auto n1 = q1.Hash(); + auto n2 = q2.Hash(); + // Hash not unique; very unlikely to happen. + if (n1 == n2 && q1 != q2) { + VLOG(1) << "Isomorphic: Weight hash collision"; + *error = true; + } + return n1 < n2; +} + +template +class Isomorphism { + using StateId = typename Arc::StateId; + + public: + Isomorphism(const Fst &fst1, const Fst &fst2, float delta) + : fst1_(fst1.Copy()), + fst2_(fst2.Copy()), + delta_(delta), + error_(false), + comp_(delta, &error_) {} + + // Checks if input FSTs are isomorphic. + bool IsIsomorphic() { + if (fst1_->Start() == kNoStateId && fst2_->Start() == kNoStateId) { + return true; + } + if (fst1_->Start() == kNoStateId || fst2_->Start() == kNoStateId) { + return false; + } + PairState(fst1_->Start(), fst2_->Start()); + while (!queue_.empty()) { + const auto &pr = queue_.front(); + if (!IsIsomorphicState(pr.first, pr.second)) return false; + queue_.pop_front(); + } + return true; + } + + bool Error() const { return error_; } + + private: + // Orders arcs for equality checking. + class ArcCompare { + public: + ArcCompare(float delta, bool *error) : delta_(delta), error_(error) {} + + bool operator()(const Arc &arc1, const Arc &arc2) const { + if (arc1.ilabel < arc2.ilabel) return true; + if (arc1.ilabel > arc2.ilabel) return false; + if (arc1.olabel < arc2.olabel) return true; + if (arc1.olabel > arc2.olabel) return false; + return WeightCompare(arc1.weight, arc2.weight, delta_, error_); + } + + private: + float delta_; + bool *error_; + }; + + // Maintains state correspondences and queue. + bool PairState(StateId s1, StateId s2) { + if (state_pairs_.size() <= s1) state_pairs_.resize(s1 + 1, kNoStateId); + if (state_pairs_[s1] == s2) { + return true; // already seen this pair + } else if (state_pairs_[s1] != kNoStateId) { + return false; // s1 already paired with another s2 + } + state_pairs_[s1] = s2; + queue_.push_back(std::make_pair(s1, s2)); + return true; + } + + // Checks if state pair is isomorphic + bool IsIsomorphicState(StateId s1, StateId s2); + + std::unique_ptr> fst1_; + std::unique_ptr> fst2_; + float delta_; // Weight equality delta. + std::vector arcs1_; // For sorting arcs on FST1. + std::vector arcs2_; // For sorting arcs on FST2. + std::vector state_pairs_; // Maintains state correspondences. + std::list> queue_; // Queue of state pairs. + bool error_; // Error flag. + ArcCompare comp_; +}; + +template +bool Isomorphism::IsIsomorphicState(StateId s1, StateId s2) { + if (!ApproxEqual(fst1_->Final(s1), fst2_->Final(s2), delta_)) return false; + auto narcs1 = fst1_->NumArcs(s1); + auto narcs2 = fst2_->NumArcs(s2); + if (narcs1 != narcs2) return false; + ArcIterator> aiter1(*fst1_, s1); + ArcIterator> aiter2(*fst2_, s2); + arcs1_.clear(); + arcs1_.reserve(narcs1); + arcs2_.clear(); + arcs2_.reserve(narcs2); + for (; !aiter1.Done(); aiter1.Next(), aiter2.Next()) { + arcs1_.push_back(aiter1.Value()); + arcs2_.push_back(aiter2.Value()); + } + std::sort(arcs1_.begin(), arcs1_.end(), comp_); + std::sort(arcs2_.begin(), arcs2_.end(), comp_); + for (size_t i = 0; i < arcs1_.size(); ++i) { + const auto &arc1 = arcs1_[i]; + const auto &arc2 = arcs2_[i]; + if (arc1.ilabel != arc2.ilabel) return false; + if (arc1.olabel != arc2.olabel) return false; + if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) return false; + if (!PairState(arc1.nextstate, arc2.nextstate)) return false; + if (i > 0) { // Checks for non-determinism. + const auto &arc0 = arcs1_[i - 1]; + if (arc1.ilabel == arc0.ilabel && arc1.olabel == arc0.olabel && + ApproxEqual(arc1.weight, arc0.weight, delta_)) { + VLOG(1) << "Isomorphic: Non-determinism as an unweighted automaton"; + error_ = true; + return false; + } + } + } + return true; +} + +} // namespace internal + +// Tests if two FSTs have the same states and arcs up to a reordering. +// Inputs should be non-deterministic when viewed as unweighted automata. +template +bool Isomorphic(const Fst &fst1, const Fst &fst2, + float delta = kDelta) { + internal::Isomorphism iso(fst1, fst2, delta); + bool result = iso.IsIsomorphic(); + if (iso.Error()) { + FSTERROR() << "Isomorphic: Cannot determine if inputs are isomorphic"; + return false; + } else { + return result; + } +} + +} // namespace fst + +#endif // FST_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/label-reachable.h b/projects/llm_framework/include/fst/label-reachable.h new file mode 100644 index 00000000..f3d7f2bc --- /dev/null +++ b/projects/llm_framework/include/fst/label-reachable.h @@ -0,0 +1,511 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to determine if a non-epsilon label can be read as the first +// non-epsilon symbol along some path from a given state. + +#ifndef FST_LABEL_REACHABLE_H_ +#define FST_LABEL_REACHABLE_H_ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Stores shareable data for label reachable class copies. +template +class LabelReachableData { + public: + using LabelIntervalSet = IntervalSet *fst, C *mapper) { + ArcMap(fst, mapper); +} + +template +void Map(MutableFst *fst, C mapper) { + ArcMap(fst, mapper); +} + +template +void Map(const Fst &ifst, MutableFst *ofst, C *mapper) { + ArcMap(ifst, ofst, mapper); +} + +template +void Map(const Fst &ifst, MutableFst *ofst, C mapper) { + ArcMap(ifst, ofst, mapper); +} + +using MapFstOptions = ArcMapFstOptions; + +template +class MapFst : public ArcMapFst { + public: + using FromArc = A; + using ToArc = B; + + using StateId = typename ToArc::StateId; + using Weight = typename ToArc::Weight; + + using State = CacheState; + + MapFst(const Fst &fst, const C &mapper, const MapFstOptions &opts) + : ArcMapFst(fst, mapper, opts) {} + + MapFst(const Fst &fst, C *mapper, const MapFstOptions &opts) + : ArcMapFst(fst, mapper, opts) {} + + MapFst(const Fst &fst, const C &mapper) + : ArcMapFst(fst, mapper) {} + + MapFst(const Fst &fst, C *mapper) : ArcMapFst(fst, mapper) {} + + // See Fst<>::Copy() for doc. + MapFst(const MapFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Get a copy of this MapFst. See Fst<>::Copy() for further doc. + MapFst *Copy(bool safe = false) const override { + return new MapFst(*this, safe); + } +}; + +// Specialization for MapFst. +template +class StateIterator> + : public StateIterator> { + public: + explicit StateIterator(const ArcMapFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for MapFst. +template +class ArcIterator> : public ArcIterator> { + public: + ArcIterator(const ArcMapFst &fst, typename A::StateId s) + : ArcIterator>(fst, s) {} +}; + +// For backwards compatibility only; use IdentityArcMapper otherwise. +template +struct IdentityMapper { + using FromArc = A; + using ToArc = A; + + ToArc operator()(const FromArc &arc) const { return arc; } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_MAP_H_ diff --git a/projects/llm_framework/include/fst/mapped-file.h b/projects/llm_framework/include/fst/mapped-file.h new file mode 100644 index 00000000..adb33c28 --- /dev/null +++ b/projects/llm_framework/include/fst/mapped-file.h @@ -0,0 +1,81 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_MAPPED_FILE_H_ +#define FST_MAPPED_FILE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +// A memory region is a simple abstraction for allocated memory or data from +// memory-mapped files. If mmap is null, then data represents an owned region +// of size bytes. Otherwise, mmap and size refer to the mapping and data is a +// casted pointer to a region contained within [mmap, mmap + size). If size is +// 0, then mmap and data refer to a block of memory managed externally by some +// other allocator. The offset is used when allocating memory to providing +// padding for alignment. +struct MemoryRegion { + void *data; + void *mmap; + size_t size; + int offset; +}; + +class MappedFile { + public: + ~MappedFile(); + + void *mutable_data() const { return region_.data; } + + const void *data() const { return region_.data; } + + // Returns a MappedFile object that contains the contents of the input stream + // strm starting from the current file position with size bytes. The memorymap + // bool is advisory, and Map will default to allocating and reading. The + // source argument needs to contain the filename that was used to open the + // input stream. + static MappedFile *Map(std::istream *istrm, bool memorymap, + const string &source, size_t size); + + // Returns a MappedFile object that contains the contents of the file referred + // to by the file descriptor starting from pos with size bytes. If the + // memory mapping fails, nullptr is returned. In contrast to Map(), this + // factory function does not backoff to allocating and reading. + static MappedFile *MapFromFileDescriptor(int fd, int pos, size_t size); + + // Creates a MappedFile object with a new[]'ed block of memory of size. The + // align argument can be used to specify a desired block alignment. + // This is RECOMMENDED FOR INTERNAL USE ONLY as it may change in future + // releases. + static MappedFile *Allocate(size_t size, int align = kArchAlignment); + + // Creates a MappedFile object pointing to a borrowed reference to data. This + // block of memory is not owned by the MappedFile object and will not be + // freed. This is RECOMMENDED FOR INTERNAL USE ONLY, may change in future + // releases. + static MappedFile *Borrow(void *data); + + // Alignment required for mapping structures in bytes. Regions of memory that + // are not aligned upon a 128-bit boundary are read from the file instead. + // This is consistent with the alignment boundary set in ConstFst and + // CompactFst. + static constexpr int kArchAlignment = 16; + + static constexpr size_t kMaxReadChunk = 256 * 1024 * 1024; // 256 MB. + + private: + explicit MappedFile(const MemoryRegion ®ion); + + MemoryRegion region_; + MappedFile(const MappedFile &) = delete; + MappedFile &operator=(const MappedFile &) = delete; +}; +} // namespace fst + +#endif // FST_MAPPED_FILE_H_ diff --git a/projects/llm_framework/include/fst/matcher-fst.h b/projects/llm_framework/include/fst/matcher-fst.h new file mode 100644 index 00000000..61e95820 --- /dev/null +++ b/projects/llm_framework/include/fst/matcher-fst.h @@ -0,0 +1,347 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to add a matcher to an FST. + +#ifndef FST_MATCHER_FST_H_ +#define FST_MATCHER_FST_H_ + +#include +#include + +#include +#include +#include + + +namespace fst { + +// Writeable matchers have the same interface as Matchers (as defined in +// matcher.h) along with the following additional methods: +// +// template +// class Matcher { +// public: +// using FST = F; +// ... +// using MatcherData = ...; // Initialization data. +// +// // Constructor with additional argument for external initialization data; +// // matcher increments its reference count on construction and decrements +// // the reference count, and deletes once the reference count has reached +// // zero. +// Matcher(const FST &fst, MatchType type, MatcherData *data); +// +// // Returns pointer to initialization data that can be passed to a Matcher +// // constructor. +// MatcherData *GetData() const; +// }; + +// The matcher initialization data class must also provide the following +// interface: +// +// class MatcherData { +// public: +// // Required copy constructor. +// MatcherData(const MatcherData &); +// +// // Required I/O methods. +// static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts); +// bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const; +// }; + +// Trivial (no-op) MatcherFst initializer functor. +template +class NullMatcherFstInit { + public: + using MatcherData = typename M::MatcherData; + using Data = AddOnPair; + using Impl = internal::AddOnImpl; + + explicit NullMatcherFstInit(std::shared_ptr *) {} +}; + +// Class adding a matcher to an FST type. Creates a new FST whose name is given +// by N. An optional functor Init can be used to initialize the FST. The Data +// template parameter allows the user to select the type of the add-on. +template < + class F, class M, const char *Name, class Init = NullMatcherFstInit, + class Data = AddOnPair> +class MatcherFst : public ImplToExpandedFst> { + public: + using FST = F; + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + using FstMatcher = M; + using MatcherData = typename FstMatcher::MatcherData; + + using Impl = internal::AddOnImpl; + using D = Data; + + friend class StateIterator>; + friend class ArcIterator>; + + MatcherFst() : ImplToExpandedFst(std::make_shared(FST(), Name)) {} + + explicit MatcherFst(const FST &fst, std::shared_ptr data = nullptr) + : ImplToExpandedFst(data ? CreateImpl(fst, Name, data) + : CreateDataAndImpl(fst, Name)) {} + + explicit MatcherFst(const Fst &fst) + : ImplToExpandedFst(CreateDataAndImpl(fst, Name)) {} + + // See Fst<>::Copy() for doc. + MatcherFst(const MatcherFst &fst, + bool safe = false) + : ImplToExpandedFst(fst, safe) {} + + // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc. + MatcherFst *Copy( + bool safe = false) const override { + return new MatcherFst(*this, safe); + } + + // Read a MatcherFst from an input stream; return nullptr on error + static MatcherFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new MatcherFst( + std::shared_ptr(impl)) + : nullptr; + } + + // Read a MatcherFst from a file; return nullptr on error + // Empty filename reads from standard input + static MatcherFst *Read( + const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? new MatcherFst( + std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + void InitStateIterator(StateIteratorData *data) const override { + return GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + return GetImpl()->InitArcIterator(s, data); + } + + FstMatcher *InitMatcher(MatchType match_type) const override { + return new FstMatcher(&GetFst(), match_type, GetSharedData(match_type)); + } + + const FST &GetFst() const { return GetImpl()->GetFst(); } + + const Data *GetAddOn() const { return GetImpl()->GetAddOn(); } + + std::shared_ptr GetSharedAddOn() const { + return GetImpl()->GetSharedAddOn(); + } + + const MatcherData *GetData(MatchType match_type) const { + const auto *data = GetAddOn(); + return match_type == MATCH_INPUT ? data->First() : data->Second(); + } + + std::shared_ptr GetSharedData(MatchType match_type) const { + const auto *data = GetAddOn(); + return match_type == MATCH_INPUT ? data->SharedFirst() + : data->SharedSecond(); + } + + protected: + using ImplToFst>::GetImpl; + + static std::shared_ptr CreateDataAndImpl(const FST &fst, + const string &name) { + FstMatcher imatcher(fst, MATCH_INPUT); + FstMatcher omatcher(fst, MATCH_OUTPUT); + return CreateImpl(fst, name, + std::make_shared(imatcher.GetSharedData(), + omatcher.GetSharedData())); + } + + static std::shared_ptr CreateDataAndImpl(const Fst &fst, + const string &name) { + FST result(fst); + return CreateDataAndImpl(result, name); + } + + static std::shared_ptr CreateImpl(const FST &fst, const string &name, + std::shared_ptr data) { + auto impl = std::make_shared(fst, name); + impl->SetAddOn(data); + Init init(&impl); + return impl; + } + + explicit MatcherFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + private: + MatcherFst &operator=(const MatcherFst &) = delete; +}; + +// Specialization for MatcherFst. +template +class StateIterator> + : public StateIterator { + public: + explicit StateIterator(const MatcherFst &fst) + : StateIterator(fst.GetImpl()->GetFst()) {} +}; + +// Specialization for MatcherFst. +template +class ArcIterator> : public ArcIterator { + public: + using StateId = typename FST::Arc::StateId; + + ArcIterator(const MatcherFst &fst, + typename FST::Arc::StateId s) + : ArcIterator(fst.GetImpl()->GetFst(), s) {} +}; + +// Specialization for MatcherFst. +template +class Matcher> { + public: + using FST = MatcherFst; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + + Matcher(const FST &fst, MatchType match_type) + : matcher_(fst.InitMatcher(match_type)) {} + + Matcher(const Matcher &matcher) : matcher_(matcher.matcher_->Copy()) {} + + Matcher *Copy() const { return new Matcher(*this); } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { matcher_->SetState(s); } + + bool Find(Label label) { return matcher_->Find(label); } + + bool Done() const { return matcher_->Done(); } + + const Arc &Value() const { return matcher_->Value(); } + + void Next() { matcher_->Next(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + private: + std::unique_ptr matcher_; +}; + +// Specialization for MatcherFst. +template +class LookAheadMatcher> { + public: + using FST = MatcherFst; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + LookAheadMatcher(const FST &fst, MatchType match_type) + : matcher_(fst.InitMatcher(match_type)) {} + + LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false) + : matcher_(matcher.matcher_->Copy(safe)) {} + + // General matcher methods. + LookAheadMatcher *Copy(bool safe = false) const { + return new LookAheadMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { matcher_->SetState(s); } + + bool Find(Label label) { return matcher_->Find(label); } + + bool Done() const { return matcher_->Done(); } + + const Arc &Value() const { return matcher_->Value(); } + + void Next() { matcher_->Next(); } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + bool LookAheadLabel(Label label) const { + return matcher_->LookAheadLabel(label); + } + + bool LookAheadFst(const Fst &fst, StateId s) { + return matcher_->LookAheadFst(fst, s); + } + + Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); } + + bool LookAheadPrefix(Arc *arc) const { + return matcher_->LookAheadPrefix(arc); + } + + void InitLookAheadFst(const Fst &fst, bool copy = false) { + matcher_->InitLookAheadFst(fst, copy); + } + + private: + std::unique_ptr matcher_; +}; + +// Useful aliases when using StdArc. + +extern const char arc_lookahead_fst_type[]; + +using StdArcLookAheadFst = + MatcherFst, + ArcLookAheadMatcher>>, + arc_lookahead_fst_type>; + +extern const char ilabel_lookahead_fst_type[]; +extern const char olabel_lookahead_fst_type[]; + +constexpr auto ilabel_lookahead_flags = + kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; + +constexpr auto olabel_lookahead_flags = + kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; + +using StdILabelLookAheadFst = MatcherFst< + ConstFst, + LabelLookAheadMatcher>, + ilabel_lookahead_flags, FastLogAccumulator>, + ilabel_lookahead_fst_type, LabelLookAheadRelabeler>; + +using StdOLabelLookAheadFst = MatcherFst< + ConstFst, + LabelLookAheadMatcher>, + olabel_lookahead_flags, FastLogAccumulator>, + olabel_lookahead_fst_type, LabelLookAheadRelabeler>; + +} // namespace fst + +#endif // FST_MATCHER_FST_H_ diff --git a/projects/llm_framework/include/fst/matcher.h b/projects/llm_framework/include/fst/matcher.h new file mode 100644 index 00000000..d9528d68 --- /dev/null +++ b/projects/llm_framework/include/fst/matcher.h @@ -0,0 +1,1575 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes to allow matching labels leaving FST states. + +#ifndef FST_MATCHER_H_ +#define FST_MATCHER_H_ + +#include +#include +#include +#include + +#include + +#include // for all internal FST accessors. + + +namespace fst { + +// Matchers find and iterate through requested labels at FST states. In the +// simplest form, these are just some associative map or search keyed on labels. +// More generally, they may implement matching special labels that represent +// sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher +// interface is: +// +// template +// class Matcher { +// public: +// using FST = F; +// using Arc = typename FST::Arc; +// using Label = typename Arc::Label; +// using StateId = typename Arc::StateId; +// using Weight = typename Arc::Weight; +// +// // Required constructors. Note: +// // -- the constructors that copy the FST arg are useful for +// // letting the matcher manage the FST through copies +// // (esp with 'safe' copies); e.g. ComposeFst depends on this. +// // -- the constructor that does not copy is useful when the +// // the FST is mutated during the lifetime of the matcher +// // (o.w. the matcher would have its own unmutated deep copy). +// +// // This makes a copy of the FST. +// Matcher(const FST &fst, MatchType type); +// // This doesn't copy the FST. +// Matcher(const FST *fst, MatchType type); +// // This makes a copy of the FST. +// // See Copy() below. +// Matcher(const Matcher &matcher, bool safe = false); +// +// // If safe = true, the copy is thread-safe. See Fst<>::Copy() for +// // further doc. +// Matcher *Copy(bool safe = false) const override; +// +// // Returns the match type that can be provided (depending on compatibility +// of the input FST). It is either the requested match type, MATCH_NONE, or +// MATCH_UNKNOWN. If test is false, a costly testing is avoided, but +// MATCH_UNKNOWN may be returned. If test is true, a definite answer is +// returned, but may involve more costly computation (e.g., visiting the FST). +// MatchType Type(bool test) const override; +// +// // Specifies the current state. +// void SetState(StateId s) final; +// +// // Finds matches to a label at the current state, returning true if a match +// // found. kNoLabel matches any non-consuming transitions, e.g., epsilon +// // transitions, which do not require a matching symbol. +// bool Find(Label label) final; +// +// // Iterator methods. Note that initially and after SetState() these have +// undefined behavior until Find() is called. +// +// bool Done() const final; +// +// const Arc &Value() const final; +// +// void Next() final; +// +// // Returns final weight of a state. +// Weight Final(StateId) const final; +// +// // Indicates preference for being the side used for matching in +// // composition. If the value is kRequirePriority, then it is +// // mandatory that it be used. Calling this method without passing the +// // current state of the matcher invalidates the state of the matcher. +// ssize_t Priority(StateId s) final; +// +// // This specifies the known FST properties as viewed from this matcher. It +// // takes as argument the input FST's known properties. +// uint64 Properties(uint64 props) const override; +// +// // Returns matcher flags. +// uint32 Flags() const override; +// +// // Returns matcher FST. +// const FST &GetFst() const override; +// }; + +// Basic matcher flags. + +// Matcher needs to be used as the matching side in composition for +// at least one state (has kRequirePriority). +constexpr uint32 kRequireMatch = 0x00000001; + +// Flags used for basic matchers (see also lookahead.h). +constexpr uint32 kMatcherFlags = kRequireMatch; + +// Matcher priority that is mandatory. +constexpr ssize_t kRequirePriority = -1; + +// Matcher interface, templated on the Arc definition; used for matcher +// specializations that are returned by the InitMatcher FST method. +template +class MatcherBase { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + virtual ~MatcherBase() {} + + // Virtual interface. + + virtual MatcherBase *Copy(bool safe = false) const = 0; + virtual MatchType Type(bool) const = 0; + virtual void SetState(StateId) = 0; + virtual bool Find(Label) = 0; + virtual bool Done() const = 0; + virtual const Arc &Value() const = 0; + virtual void Next() = 0; + virtual const Fst &GetFst() const = 0; + virtual uint64 Properties(uint64) const = 0; + + // Trivial implementations that can be used by derived classes. Full + // devirtualization is expected for any derived class marked final. + virtual uint32 Flags() const { return 0; } + + virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); } + + virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); } +}; + +// A matcher that expects sorted labels on the side to be matched. +// If match_type == MATCH_INPUT, epsilons match the implicit self-loop +// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any +// actual epsilon transitions. If match_type == MATCH_OUTPUT, then +// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched. +template +class SortedMatcher : public MatcherBase { + public: + using FST = F; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MatcherBase::Flags; + using MatcherBase::Properties; + + // Labels >= binary_label will be searched for by binary search; + // o.w. linear search is used. + // This makes a copy of the FST. + SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1) + : SortedMatcher(fst.Copy(), match_type, binary_label) { + owned_fst_.reset(&fst_); + } + + // Labels >= binary_label will be searched for by binary search; + // o.w. linear search is used. + // This doesn't copy the FST. + SortedMatcher(const FST *fst, MatchType match_type, Label binary_label = 1) + : fst_(*fst), + state_(kNoStateId), + aiter_(nullptr), + match_type_(match_type), + binary_label_(binary_label), + match_label_(kNoLabel), + narcs_(0), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false), + aiter_pool_(1) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + std::swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "SortedMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + SortedMatcher(const SortedMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + state_(kNoStateId), + aiter_(nullptr), + match_type_(matcher.match_type_), + binary_label_(matcher.binary_label_), + match_label_(kNoLabel), + narcs_(0), + loop_(matcher.loop_), + error_(matcher.error_), + aiter_pool_(1) {} + + ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); } + + SortedMatcher *Copy(bool safe = false) const override { + return new SortedMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if (match_type_ == MATCH_NONE) return match_type_; + const auto true_prop = + match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; + const auto false_prop = + match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted; + const auto props = fst_.Properties(true_prop | false_prop, test); + if (props & true_prop) { + return match_type_; + } else if (props & false_prop) { + return MATCH_NONE; + } else { + return MATCH_UNKNOWN; + } + } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + if (match_type_ == MATCH_NONE) { + FSTERROR() << "SortedMatcher: Bad match type"; + error_ = true; + } + Destroy(aiter_, &aiter_pool_); + aiter_ = new (&aiter_pool_) ArcIterator(fst_, s); + aiter_->SetFlags(kArcNoCache, kArcNoCache); + narcs_ = internal::NumArcs(fst_, s); + loop_.nextstate = s; + } + + bool Find(Label match_label) final { + exact_match_ = true; + if (error_) { + current_loop_ = false; + match_label_ = kNoLabel; + return false; + } + current_loop_ = match_label == 0; + match_label_ = match_label == kNoLabel ? 0 : match_label; + if (Search()) { + return true; + } else { + return current_loop_; + } + } + + // Positions matcher to the first position where inserting match_label would + // maintain the sort order. + void LowerBound(Label label) { + exact_match_ = false; + current_loop_ = false; + if (error_) { + match_label_ = kNoLabel; + return; + } + match_label_ = label; + Search(); + } + + // After Find(), returns false if no more exact matches. + // After LowerBound(), returns false if no more arcs. + bool Done() const final { + if (current_loop_) return false; + if (aiter_->Done()) return true; + if (!exact_match_) return false; + aiter_->SetFlags(match_type_ == MATCH_INPUT ? + kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + return GetLabel() != match_label_; + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); + return aiter_->Value(); + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + aiter_->Next(); + } + } + + Weight Final(StateId s) const final { + return MatcherBase::Final(s); + } + + ssize_t Priority(StateId s) final { + return MatcherBase::Priority(s); + } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops | (error_ ? kError : 0); + } + + size_t Position() const { return aiter_ ? aiter_->Position() : 0; } + + private: + Label GetLabel() const { + const auto &arc = aiter_->Value(); + return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel; + } + + bool BinarySearch(); + bool LinearSearch(); + bool Search(); + + std::unique_ptr owned_fst_; // FST ptr if owned. + const FST &fst_; // FST for matching. + StateId state_; // Matcher state. + ArcIterator *aiter_; // Iterator for current state. + MatchType match_type_; // Type of match to perform. + Label binary_label_; // Least label for binary search. + Label match_label_; // Current label to be matched. + size_t narcs_; // Current state arc count. + Arc loop_; // For non-consuming symbols. + bool current_loop_; // Current arc is the implicit loop. + bool exact_match_; // Exact match or lower bound? + bool error_; // Error encountered? + MemoryPool> aiter_pool_; // Pool of arc iterators. +}; + +// Returns true iff match to match_label_. The arc iterator is positioned at the +// lower bound, that is, the first element greater than or equal to +// match_label_, or the end if all elements are less than match_label_. +// If multiple elements are equal to the `match_label_`, returns the rightmost +// one. +template +inline bool SortedMatcher::BinarySearch() { + size_t size = narcs_; + if (size == 0) { + return false; + } + size_t high = size - 1; + while (size > 1) { + const size_t half = size / 2; + const size_t mid = high - half; + aiter_->Seek(mid); + if (GetLabel() >= match_label_) { + high = mid; + } + size -= half; + } + aiter_->Seek(high); + const auto label = GetLabel(); + if (label == match_label_) { + return true; + } + if (label < match_label_) { + aiter_->Next(); + } + return false; +} + +// Returns true iff match to match_label_, positioning arc iterator at lower +// bound. +template +inline bool SortedMatcher::LinearSearch() { + for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { + const auto label = GetLabel(); + if (label == match_label_) return true; + if (label > match_label_) break; + } + return false; +} + +// Returns true iff match to match_label_, positioning arc iterator at lower +// bound. +template +inline bool SortedMatcher::Search() { + aiter_->SetFlags(match_type_ == MATCH_INPUT ? + kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + if (match_label_ >= binary_label_) { + return BinarySearch(); + } else { + return LinearSearch(); + } +} + +// A matcher that stores labels in a per-state hash table populated upon the +// first visit to that state. Sorting is not required. Treatment of +// epsilons are the same as with SortedMatcher. +template +class HashMatcher : public MatcherBase { + public: + using FST = F; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MatcherBase::Flags; + using MatcherBase::Final; + using MatcherBase::Priority; + + // This makes a copy of the FST. + HashMatcher(const FST &fst, MatchType match_type) + : HashMatcher(fst.Copy(), match_type) { + owned_fst_.reset(&fst_); + } + + // This doesn't copy the FST. + HashMatcher(const FST *fst, MatchType match_type) + : fst_(*fst), + state_(kNoStateId), + match_type_(match_type), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false), + state_table_(std::make_shared()) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + std::swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "HashMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + HashMatcher(const HashMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + state_(kNoStateId), + match_type_(matcher.match_type_), + loop_(matcher.loop_), + error_(matcher.error_), + state_table_( + safe ? std::make_shared() : matcher.state_table_) {} + + HashMatcher *Copy(bool safe = false) const override { + return new HashMatcher(*this, safe); + } + + // The argument is ignored as there are no relevant properties to test. + MatchType Type(bool test) const override { return match_type_; } + + void SetState(StateId s) final; + + bool Find(Label label) final { + current_loop_ = label == 0; + if (label == 0) { + Search(label); + return true; + } + if (label == kNoLabel) label = 0; + return Search(label); + } + + bool Done() const final { + if (current_loop_) return false; + return label_it_ == label_end_; + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + aiter_->Seek(label_it_->second); + return aiter_->Value(); + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + ++label_it_; + } + } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops | (error_ ? kError : 0); + } + + private: + Label GetLabel() const { + const auto &arc = aiter_->Value(); + return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel; + } + + bool Search(Label match_label); + + using LabelTable = std::unordered_multimap; + using StateTable = std::unordered_map>; + + std::unique_ptr owned_fst_; // ptr to FST if owned. + const FST &fst_; // FST for matching. + StateId state_; // Matcher state. + MatchType match_type_; + Arc loop_; // The implicit loop itself. + bool current_loop_; // Is the current arc the implicit loop? + bool error_; // Error encountered? + std::unique_ptr> aiter_; + std::shared_ptr state_table_; // Table from state to label table. + LabelTable *label_table_; // Pointer to current state's label table. + typename LabelTable::iterator label_it_; // Position for label. + typename LabelTable::iterator label_end_; // Position for last label + 1. +}; + +template +void HashMatcher::SetState(typename FST::Arc::StateId s) { + if (state_ == s) return; + // Resets everything for the state. + state_ = s; + loop_.nextstate = state_; + aiter_.reset(new ArcIterator(fst_, state_)); + if (match_type_ == MATCH_NONE) { + FSTERROR() << "HashMatcher: Bad match type"; + error_ = true; + } + // Attempts to insert a new label table. + auto it_and_success = state_table_->emplace( + state_, std::unique_ptr(new LabelTable())); + // Sets instance's pointer to the label table for this state. + label_table_ = it_and_success.first->second.get(); + // If it already exists, no additional work is done and we simply return. + if (!it_and_success.second) return; + // Otherwise, populate this new table. + // Populates the label table. + label_table_->reserve(internal::NumArcs(fst_, state_)); + const auto aiter_flags = + (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) | + kArcNoCache; + aiter_->SetFlags(aiter_flags, kArcFlags); + for (; !aiter_->Done(); aiter_->Next()) { + label_table_->emplace(GetLabel(), aiter_->Position()); + } + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); +} + +template +inline bool HashMatcher::Search(typename FST::Arc::Label match_label) { + auto range = label_table_->equal_range(match_label); + label_it_ = range.first; + label_end_ = range.second; + if (label_it_ == label_end_) return false; + aiter_->Seek(label_it_->second); + return true; +} + +// Specifies whether we rewrite both the input and output sides during matching. +enum MatcherRewriteMode { + MATCHER_REWRITE_AUTO = 0, // Rewrites both sides iff acceptor. + MATCHER_REWRITE_ALWAYS, + MATCHER_REWRITE_NEVER +}; + +// For any requested label that doesn't match at a state, this matcher +// considers the *unique* transition that matches the label 'phi_label' +// (phi = 'fail'), and recursively looks for a match at its +// destination. When 'phi_loop' is true, if no match is found but a +// phi self-loop is found, then the phi transition found is returned +// with the phi_label rewritten as the requested label (both sides if +// an acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'phi_label'). If 'phi_label' is +// kNoLabel, this special matching is not done. PhiMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by PhiMatcher. The user can instead pass in this +// object; in that case, PhiMatcher takes its ownership. +// Phi non-determinism not supported. No non-consuming symbols other +// than epsilon supported with the underlying template argument matcher. +template +class PhiMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + phi_label_(phi_label), + state_(kNoStateId), + phi_loop_(phi_loop), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "PhiMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. + PhiMatcher(const FST *fst, MatchType match_type, Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : PhiMatcher(*fst, match_type, phi_label, phi_loop, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + + // This makes a copy of the FST. + PhiMatcher(const PhiMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + phi_label_(matcher.phi_label_), + rewrite_both_(matcher.rewrite_both_), + state_(kNoStateId), + phi_loop_(matcher.phi_loop_), + error_(matcher.error_) {} + + PhiMatcher *Copy(bool safe = false) const override { + return new PhiMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + matcher_->SetState(s); + state_ = s; + has_phi_ = phi_label_ != kNoLabel; + } + + bool Find(Label match_label) final; + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { + return matcher_->Value(); + } else if (phi_match_ == 0) { // Virtual epsilon loop. + phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); + if (match_type_ == MATCH_OUTPUT) { + std::swap(phi_arc_.ilabel, phi_arc_.olabel); + } + return phi_arc_; + } else { + phi_arc_ = matcher_->Value(); + phi_arc_.weight = Times(phi_weight_, phi_arc_.weight); + if (phi_match_ != kNoLabel) { // Phi loop match. + if (rewrite_both_) { + if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_; + if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_; + } else if (match_type_ == MATCH_INPUT) { + phi_arc_.ilabel = phi_match_; + } else { + phi_arc_.olabel = phi_match_; + } + } + return phi_arc_; + } + } + + void Next() final { matcher_->Next(); } + + Weight Final(StateId s) const final { + auto weight = matcher_->Final(s); + if (phi_label_ == kNoLabel || weight != Weight::Zero()) { + return weight; + } + weight = Weight::One(); + matcher_->SetState(s); + while (matcher_->Final(s) == Weight::Zero()) { + if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break; + weight = Times(weight, matcher_->Value().weight); + if (s == matcher_->Value().nextstate) { + return Weight::Zero(); // Does not follow phi self-loops. + } + s = matcher_->Value().nextstate; + matcher_->SetState(s); + } + weight = Times(weight, matcher_->Final(s)); + return weight; + } + + ssize_t Priority(StateId s) final { + if (phi_label_ != kNoLabel) { + matcher_->SetState(s); + const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_); + return has_phi ? kRequirePriority : matcher_->Priority(s); + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label PhiLabel() const { return phi_label_; } + + private: + mutable std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label phi_label_; // Label that represents the phi transition. + bool rewrite_both_; // Rewrite both sides when both are phi_label_? + bool has_phi_; // Are there possibly phis at the current state? + Label phi_match_; // Current label that matches phi loop. + mutable Arc phi_arc_; // Arc to return. + StateId state_; // Matcher state. + Weight phi_weight_; // Product of the weights of phi transitions taken. + bool phi_loop_; // When true, phi self-loop are allowed and treated + // as rho (required for Aho-Corasick). + bool error_; // Error encountered? + + PhiMatcher &operator=(const PhiMatcher &) = delete; +}; + +template +inline bool PhiMatcher::Find(Label label) { + if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { + FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; + error_ = true; + return false; + } + matcher_->SetState(state_); + phi_match_ = kNoLabel; + phi_weight_ = Weight::One(); + // If phi_label_ == 0, there are no more true epsilon arcs. + if (phi_label_ == 0) { + if (label == kNoLabel) { + return false; + } + if (label == 0) { // but a virtual epsilon loop needs to be returned. + if (!matcher_->Find(kNoLabel)) { + return matcher_->Find(0); + } else { + phi_match_ = 0; + return true; + } + } + } + if (!has_phi_ || label == 0 || label == kNoLabel) { + return matcher_->Find(label); + } + auto s = state_; + while (!matcher_->Find(label)) { + // Look for phi transition (if phi_label_ == 0, we need to look + // for -1 to avoid getting the virtual self-loop) + if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false; + if (phi_loop_ && matcher_->Value().nextstate == s) { + phi_match_ = label; + return true; + } + phi_weight_ = Times(phi_weight_, matcher_->Value().weight); + s = matcher_->Value().nextstate; + matcher_->Next(); + if (!matcher_->Done()) { + FSTERROR() << "PhiMatcher: Phi non-determinism not supported"; + error_ = true; + } + matcher_->SetState(s); + } + return true; +} + +template +inline uint64 PhiMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoIEpsilons; + } + if (rewrite_both_) { + return outprops & + ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kODeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoOEpsilons; + } + if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kIDeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "PhiMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// For any requested label that doesn't match at a state, this matcher +// considers all transitions that match the label 'rho_label' (rho = +// 'rest'). Each such rho transition found is returned with the +// rho_label rewritten as the requested label (both sides if an +// acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'rho_label'). If 'rho_label' is +// kNoLabel, this special matching is not done. RhoMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by RhoMatcher. The user can instead pass in this +// object; in that case, RhoMatcher takes its ownership. +// No non-consuming symbols other than epsilon supported with +// the underlying template argument matcher. +template +class RhoMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + rho_label_(rho_label), + error_(false), + state_(kNoStateId), + has_rho_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "RhoMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rho_label == 0) { + FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label"; + rho_label_ = kNoLabel; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. + RhoMatcher(const FST *fst, MatchType match_type, Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : RhoMatcher(*fst, match_type, rho_label, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + // This makes a copy of the FST. + RhoMatcher(const RhoMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + rho_label_(matcher.rho_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_), + state_(kNoStateId), + has_rho_(false) {} + + RhoMatcher *Copy(bool safe = false) const override { + return new RhoMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + matcher_->SetState(s); + has_rho_ = rho_label_ != kNoLabel; + } + + bool Find(Label label) final { + if (label == rho_label_ && rho_label_ != kNoLabel) { + FSTERROR() << "RhoMatcher::Find: bad label (rho)"; + error_ = true; + return false; + } + if (matcher_->Find(label)) { + rho_match_ = kNoLabel; + return true; + } else if (has_rho_ && label != 0 && label != kNoLabel && + (has_rho_ = matcher_->Find(rho_label_))) { + rho_match_ = label; + return true; + } else { + return false; + } + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if (rho_match_ == kNoLabel) { + return matcher_->Value(); + } else { + rho_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_; + if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_; + } else if (match_type_ == MATCH_INPUT) { + rho_arc_.ilabel = rho_match_; + } else { + rho_arc_.olabel = rho_match_; + } + return rho_arc_; + } + } + + void Next() final { matcher_->Next(); } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { + state_ = s; + matcher_->SetState(s); + has_rho_ = matcher_->Find(rho_label_); + if (has_rho_) { + return kRequirePriority; + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label RhoLabel() const { return rho_label_; } + + private: + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label rho_label_; // Label that represents the rho transition + bool rewrite_both_; // Rewrite both sides when both are rho_label_? + Label rho_match_; // Current label that matches rho transition. + mutable Arc rho_arc_; // Arc to return when rho match. + bool error_; // Error encountered? + StateId state_; // Matcher state. + bool has_rho_; // Are there possibly rhos at the current state? +}; + +template +inline uint64 RhoMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (rewrite_both_) { + return outprops & + ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kODeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kIDeterministic | kAcceptor | kString | kOLabelSorted | + kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "RhoMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// For any requested label, this matcher considers all transitions +// that match the label 'sigma_label' (sigma = "any"), and this in +// additions to transitions with the requested label. Each such sigma +// transition found is returned with the sigma_label rewritten as the +// requested label (both sides if an acceptor, or if 'rewrite_both' is +// true and both input and output labels of the found transition are +// 'sigma_label'). If 'sigma_label' is kNoLabel, this special +// matching is not done. SigmaMatcher is templated itself on a +// matcher, which is used to perform the underlying matching. By +// default, the underlying matcher is constructed by SigmaMatcher. +// The user can instead pass in this object; in that case, +// SigmaMatcher takes its ownership. No non-consuming symbols other +// than epsilon supported with the underlying template argument matcher. +template +class SigmaMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + SigmaMatcher(const FST &fst, MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + sigma_label_(sigma_label), + error_(false), + state_(kNoStateId) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "SigmaMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (sigma_label == 0) { + FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label"; + sigma_label_ = kNoLabel; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. + SigmaMatcher(const FST *fst, MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : SigmaMatcher(*fst, match_type, sigma_label, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + // This makes a copy of the FST. + SigmaMatcher(const SigmaMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + sigma_label_(matcher.sigma_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_), + state_(kNoStateId) {} + + SigmaMatcher *Copy(bool safe = false) const override { + return new SigmaMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + matcher_->SetState(s); + has_sigma_ = + (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false; + } + + bool Find(Label match_label) final { + match_label_ = match_label; + if (match_label == sigma_label_ && sigma_label_ != kNoLabel) { + FSTERROR() << "SigmaMatcher::Find: bad label (sigma)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + sigma_match_ = kNoLabel; + return true; + } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel && + matcher_->Find(sigma_label_)) { + sigma_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if (sigma_match_ == kNoLabel) { + return matcher_->Value(); + } else { + sigma_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_; + if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_; + } else if (match_type_ == MATCH_INPUT) { + sigma_arc_.ilabel = sigma_match_; + } else { + sigma_arc_.olabel = sigma_match_; + } + return sigma_arc_; + } + } + + void Next() final { + matcher_->Next(); + if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) && + (match_label_ > 0)) { + matcher_->Find(sigma_label_); + sigma_match_ = match_label_; + } + } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { + if (sigma_label_ != kNoLabel) { + SetState(s); + return has_sigma_ ? kRequirePriority : matcher_->Priority(s); + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label SigmaLabel() const { return sigma_label_; } + + private: + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label sigma_label_; // Label that represents the sigma transition. + bool rewrite_both_; // Rewrite both sides when both are sigma_label_? + bool has_sigma_; // Are there sigmas at the current state? + Label sigma_match_; // Current label that matches sigma transition. + mutable Arc sigma_arc_; // Arc to return when sigma match. + Label match_label_; // Label being matched. + bool error_; // Error encountered? + StateId state_; // Matcher state. +}; + +template +inline uint64 SigmaMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted | kString); + } else if (match_type_ == MATCH_INPUT) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kILabelSorted | kNotILabelSorted | kString | + kAcceptor); + } else if (match_type_ == MATCH_OUTPUT) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString | + kAcceptor); + } else { + // Shouldn't ever get here. + FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// Flags for MultiEpsMatcher. + +// Return multi-epsilon arcs for Find(kNoLabel). +const uint32 kMultiEpsList = 0x00000001; + +// Return a kNolabel loop for Find(multi_eps). +const uint32 kMultiEpsLoop = 0x00000002; + +// MultiEpsMatcher: allows treating multiple non-0 labels as +// non-consuming labels in addition to 0 that is always +// non-consuming. Precise behavior controlled by 'flags' argument. By +// default, the underlying matcher is constructed by +// MultiEpsMatcher. The user can instead pass in this object; in that +// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is +// true. +template +class MultiEpsMatcher { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + MultiEpsMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = nullptr, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + Init(match_type); + } + + // This doesn't copy the FST. + MultiEpsMatcher(const FST *fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = nullptr, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + Init(match_type); + } + + // This makes a copy of the FST. + MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + flags_(matcher.flags_), + own_matcher_(true), + multi_eps_labels_(matcher.multi_eps_labels_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ~MultiEpsMatcher() { + if (own_matcher_) delete matcher_; + } + + MultiEpsMatcher *Copy(bool safe = false) const { + return new MultiEpsMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId state) { + matcher_->SetState(state); + loop_.nextstate = state; + } + + bool Find(Label label); + + bool Done() const { return done_; } + + const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); } + + void Next() { + if (!current_loop_) { + matcher_->Next(); + done_ = matcher_->Done(); + if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) { + ++multi_eps_iter_; + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) { + ++multi_eps_iter_; + } + if (multi_eps_iter_ != multi_eps_labels_.End()) { + done_ = false; + } else { + done_ = !matcher_->Find(kNoLabel); + } + } + } else { + done_ = true; + } + } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + const M *GetMatcher() const { return matcher_; } + + Weight Final(StateId s) const { return matcher_->Final(s); } + + uint32 Flags() const { return matcher_->Flags(); } + + ssize_t Priority(StateId s) { return matcher_->Priority(s); } + + void AddMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Insert(label); + } + } + + void RemoveMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Erase(label); + } + } + + void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); } + + private: + void Init(MatchType match_type) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + M *matcher_; + uint32 flags_; + bool own_matcher_; // Does this class delete the matcher? + + // Multi-eps label set. + CompactSet multi_eps_labels_; + typename CompactSet::const_iterator multi_eps_iter_; + + bool current_loop_; // Current arc is the implicit loop? + mutable Arc loop_; // For non-consuming symbols. + bool done_; // Matching done? + + MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete; +}; + +template +inline bool MultiEpsMatcher::Find(Label label) { + multi_eps_iter_ = multi_eps_labels_.End(); + current_loop_ = false; + bool ret; + if (label == 0) { + ret = matcher_->Find(0); + } else if (label == kNoLabel) { + if (flags_ & kMultiEpsList) { + // Returns all non-consuming arcs (including epsilon). + multi_eps_iter_ = multi_eps_labels_.Begin(); + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) { + ++multi_eps_iter_; + } + if (multi_eps_iter_ != multi_eps_labels_.End()) { + ret = true; + } else { + ret = matcher_->Find(kNoLabel); + } + } else { + // Returns all epsilon arcs. + ret = matcher_->Find(kNoLabel); + } + } else if ((flags_ & kMultiEpsLoop) && + multi_eps_labels_.Find(label) != multi_eps_labels_.End()) { + // Returns implicit loop. + current_loop_ = true; + ret = true; + } else { + ret = matcher_->Find(label); + } + done_ = !ret; + return ret; +} + +// This class discards any implicit matches (e.g., the implicit epsilon +// self-loops in the SortedMatcher). Matchers are most often used in +// composition/intersection where the implicit matches are needed +// e.g. for epsilon processing. However, if a matcher is simply being +// used to look-up explicit label matches, this class saves the user +// from having to check for and discard the unwanted implicit matches +// themselves. +template +class ExplicitMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST. + ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + error_(false) {} + + // This doesn't copy the FST. + ExplicitMatcher(const FST *fst, MatchType match_type, M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + error_(false) {} + + // This makes a copy of the FST. + ExplicitMatcher(const ExplicitMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + error_(matcher.error_) {} + + ExplicitMatcher *Copy(bool safe = false) const override { + return new ExplicitMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { matcher_->SetState(s); } + + bool Find(Label label) final { + matcher_->Find(label); + CheckArc(); + return !Done(); + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { return matcher_->Value(); } + + void Next() final { + matcher_->Next(); + CheckArc(); + } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { return matcher_->Priority(s); } + + const FST &GetFst() const final { return matcher_->GetFst(); } + + uint64 Properties(uint64 inprops) const override { + return matcher_->Properties(inprops); + } + + const M *GetMatcher() const { return matcher_.get(); } + + uint32 Flags() const override { return matcher_->Flags(); } + + private: + // Checks current arc if available and explicit. If not available, stops. If + // not explicit, checks next ones. + void CheckArc() { + for (; !matcher_->Done(); matcher_->Next()) { + const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel + : matcher_->Value().olabel; + if (label != kNoLabel) return; + } + } + + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + bool error_; // Error encountered? +}; + +// Generic matcher, templated on the FST definition. +// +// Here is a typical use: +// +// Matcher matcher(fst, MATCH_INPUT); +// matcher.SetState(state); +// if (matcher.Find(label)) +// for (; !matcher.Done(); matcher.Next()) { +// auto &arc = matcher.Value(); +// ... +// } +template +class Matcher { + public: + using FST = F; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST. + Matcher(const FST &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + base_(owned_fst_->InitMatcher(match_type)) { + if (!base_) base_.reset(new SortedMatcher(owned_fst_.get(), + match_type)); + } + + // This doesn't copy the FST. + Matcher(const FST *fst, MatchType match_type) + : base_(fst->InitMatcher(match_type)) { + if (!base_) base_.reset(new SortedMatcher(fst, match_type)); + } + + // This makes a copy of the FST. + Matcher(const Matcher &matcher, bool safe = false) + : base_(matcher.base_->Copy(safe)) { } + + // Takes ownership of the provided matcher. + explicit Matcher(MatcherBase *base_matcher) + : base_(base_matcher) { } + + Matcher *Copy(bool safe = false) const { + return new Matcher(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + + void SetState(StateId s) { base_->SetState(s); } + + bool Find(Label label) { return base_->Find(label); } + + bool Done() const { return base_->Done(); } + + const Arc &Value() const { return base_->Value(); } + + void Next() { base_->Next(); } + + const FST &GetFst() const { + return static_cast(base_->GetFst()); + } + + uint64 Properties(uint64 props) const { return base_->Properties(props); } + + Weight Final(StateId s) const { return base_->Final(s); } + + uint32 Flags() const { return base_->Flags() & kMatcherFlags; } + + ssize_t Priority(StateId s) { return base_->Priority(s); } + + private: + std::unique_ptr owned_fst_; + std::unique_ptr> base_; +}; + +} // namespace fst + +#endif // FST_MATCHER_H_ diff --git a/projects/llm_framework/include/fst/memory.h b/projects/llm_framework/include/fst/memory.h new file mode 100644 index 00000000..c1f0bddc --- /dev/null +++ b/projects/llm_framework/include/fst/memory.h @@ -0,0 +1,443 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST memory utilities. + +#ifndef FST_MEMORY_H_ +#define FST_MEMORY_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Default block allocation size. +constexpr int kAllocSize = 64; + +// Minimum number of allocations per block. +constexpr int kAllocFit = 4; + +// Base class for MemoryArena that allows (e.g.) MemoryArenaCollection to +// easily manipulate collections of variously sized arenas. +class MemoryArenaBase { + public: + virtual ~MemoryArenaBase() {} + virtual size_t Size() const = 0; +}; + +namespace internal { + +// Allocates 'size' unintialized memory chunks of size object_size from +// underlying blocks of (at least) size 'block_size * object_size'. +// All blocks are freed when this class is deleted. Result of allocate() will +// be aligned to object_size. +template +class MemoryArenaImpl : public MemoryArenaBase { + public: + enum { kObjectSize = object_size }; + + explicit MemoryArenaImpl(size_t block_size = kAllocSize) + : block_size_(block_size * kObjectSize), block_pos_(0) { + blocks_.emplace_front(new char[block_size_]); + } + + void *Allocate(size_t size) { + const auto byte_size = size * kObjectSize; + if (byte_size * kAllocFit > block_size_) { + // Large block; adds new large block. + auto *ptr = new char[byte_size]; + blocks_.emplace_back(ptr); + return ptr; + } + if (block_pos_ + byte_size > block_size_) { + // Doesn't fit; adds new standard block. + auto *ptr = new char[block_size_]; + block_pos_ = 0; + blocks_.emplace_front(ptr); + } + // Fits; uses current block. + auto *ptr = blocks_.front().get() + block_pos_; + block_pos_ += byte_size; + return ptr; + } + + size_t Size() const override { return kObjectSize; } + + private: + const size_t block_size_; // Default block size in bytes. + size_t block_pos_; // Current position in block in bytes. + std::list> blocks_; // List of allocated blocks. +}; + +} // namespace internal + +template +using MemoryArena = internal::MemoryArenaImpl; + +// Base class for MemoryPool that allows (e.g.) MemoryPoolCollection to easily +// manipulate collections of variously sized pools. +class MemoryPoolBase { + public: + virtual ~MemoryPoolBase() {} + virtual size_t Size() const = 0; +}; + +namespace internal { + +// Allocates and frees initially uninitialized memory chunks of size +// object_size. Keeps an internal list of freed chunks that are reused (as is) +// on the next allocation if available. Chunks are constructed in blocks of size +// 'pool_size'. +template +class MemoryPoolImpl : public MemoryPoolBase { + public: + enum { kObjectSize = object_size }; + + struct Link { + char buf[kObjectSize]; + Link *next; + }; + + explicit MemoryPoolImpl(size_t pool_size) + : mem_arena_(pool_size), free_list_(nullptr) {} + + void *Allocate() { + if (free_list_ == nullptr) { + auto *link = static_cast(mem_arena_.Allocate(1)); + link->next = nullptr; + return link; + } else { + auto *link = free_list_; + free_list_ = link->next; + return link; + } + } + + void Free(void *ptr) { + if (ptr) { + auto *link = static_cast(ptr); + link->next = free_list_; + free_list_ = link; + } + } + + size_t Size() const override { return kObjectSize; } + + private: + MemoryArena mem_arena_; + Link *free_list_; + + MemoryPoolImpl(const MemoryPoolImpl &) = delete; + MemoryPoolImpl &operator=(const MemoryPoolImpl &) = delete; +}; + +} // namespace internal + +// Allocates and frees initially uninitialized memory chunks of size sizeof(T). +// All memory is freed when the class is deleted. The result of Allocate() will +// be suitably memory-aligned. Combined with placement operator new and destroy +// functions for the T class, this can be used to improve allocation efficiency. +// See nlp/fst/lib/visit.h (global new) and nlp/fst/lib/dfs-visit.h (class new) +// for examples. +template +class MemoryPool : public internal::MemoryPoolImpl { + public: + // 'pool_size' specifies the size of the initial pool and how it is extended. + MemoryPool(size_t pool_size = kAllocSize) + : internal::MemoryPoolImpl(pool_size) {} +}; + +// Stores a collection of memory arenas. +class MemoryArenaCollection { + public: + // 'block_size' specifies the block size of the arenas. + explicit MemoryArenaCollection(size_t block_size = kAllocSize) + : block_size_(block_size), ref_count_(1) {} + + template + MemoryArena *Arena() { + if (sizeof(T) >= arenas_.size()) arenas_.resize(sizeof(T) + 1); + MemoryArenaBase *arena = arenas_[sizeof(T)].get(); + if (arena == nullptr) { + arena = new MemoryArena(block_size_); + arenas_[sizeof(T)].reset(arena); + } + return static_cast *>(arena); + } + + size_t BlockSize() const { return block_size_; } + + size_t RefCount() const { return ref_count_; } + + size_t IncrRefCount() { return ++ref_count_; } + + size_t DecrRefCount() { return --ref_count_; } + + private: + size_t block_size_; + size_t ref_count_; + std::vector> arenas_; +}; + +// Stores a collection of memory pools +class MemoryPoolCollection { + public: + // 'pool_size' specifies the size of initial pool and how it is extended. + explicit MemoryPoolCollection(size_t pool_size = kAllocSize) + : pool_size_(pool_size), ref_count_(1) {} + + template + MemoryPool *Pool() { + if (sizeof(T) >= pools_.size()) pools_.resize(sizeof(T) + 1); + MemoryPoolBase *pool = pools_[sizeof(T)].get(); + if (pool == nullptr) { + pool = new MemoryPool(pool_size_); + pools_[sizeof(T)].reset(pool); + } + return static_cast *>(pool); + } + + size_t PoolSize() const { return pool_size_; } + + size_t RefCount() const { return ref_count_; } + + size_t IncrRefCount() { return ++ref_count_; } + + size_t DecrRefCount() { return --ref_count_; } + + private: + size_t pool_size_; + size_t ref_count_; + std::vector> pools_; +}; + +// STL allocator using memory arenas. Memory is allocated from underlying +// blocks of size 'block_size * sizeof(T)'. Memory is freed only when all +// objects using this allocator are destroyed and there is otherwise no reuse +// (unlike PoolAllocator). +// +// This allocator has object-local state so it should not be used with splicing +// or swapping operations between objects created with different allocators nor +// should it be used if copies must be thread-safe. The result of allocate() +// will be suitably memory-aligned. +template +class BlockAllocator { + public: + using Allocator = std::allocator; + using size_type = typename Allocator::size_type; + using difference_type = typename Allocator::difference_type; + using pointer = typename Allocator::pointer; + using const_pointer = typename Allocator::const_pointer; + using reference = typename Allocator::reference; + using const_reference = typename Allocator::const_reference; + using value_type = typename Allocator::value_type; + + template + struct rebind { + using other = BlockAllocator; + }; + + explicit BlockAllocator(size_t block_size = kAllocSize) + : arenas_(new MemoryArenaCollection(block_size)) {} + + BlockAllocator(const BlockAllocator &arena_alloc) + : arenas_(arena_alloc.Arenas()) { + Arenas()->IncrRefCount(); + } + + template + explicit BlockAllocator(const BlockAllocator &arena_alloc) + : arenas_(arena_alloc.Arenas()) { + Arenas()->IncrRefCount(); + } + + ~BlockAllocator() { + if (Arenas()->DecrRefCount() == 0) delete Arenas(); + } + + pointer address(reference ref) const { return Allocator().address(ref); } + + const_pointer address(const_reference ref) const { + return Allocator().address(ref); + } + + size_type max_size() const { return Allocator().max_size(); } + + template + void construct(U *p, Args &&... args) { + Allocator().construct(p, std::forward(args)...); + } + + void destroy(pointer p) { Allocator().destroy(p); } + + pointer allocate(size_type n, const void *hint = nullptr) { + if (n * kAllocFit <= kAllocSize) { + return static_cast(Arena()->Allocate(n)); + } else { + return Allocator().allocate(n, hint); + } + } + + void deallocate(pointer p, size_type n) { + if (n * kAllocFit > kAllocSize) Allocator().deallocate(p, n); + } + + MemoryArenaCollection *Arenas() const { return arenas_; } + + private: + MemoryArena *Arena() { return arenas_->Arena(); } + + MemoryArenaCollection *arenas_; + + BlockAllocator operator=(const BlockAllocator &); +}; + +template +bool operator==(const BlockAllocator &alloc1, + const BlockAllocator &alloc2) { + return false; +} + +template +bool operator!=(const BlockAllocator &alloc1, + const BlockAllocator &alloc2) { + return true; +} + +// STL allocator using memory pools. Memory is allocated from underlying +// blocks of size 'block_size * sizeof(T)'. Keeps an internal list of freed +// chunks thare are reused on the next allocation. +// +// This allocator has object-local state so it should not be used with splicing +// or swapping operations between objects created with different allocators nor +// should it be used if copies must be thread-safe. The result of allocate() +// will be suitably memory-aligned. +template +class PoolAllocator { + public: + using Allocator = std::allocator; + using size_type = typename Allocator::size_type; + using difference_type = typename Allocator::difference_type; + using pointer = typename Allocator::pointer; + using const_pointer = typename Allocator::const_pointer; + using reference = typename Allocator::reference; + using const_reference = typename Allocator::const_reference; + using value_type = typename Allocator::value_type; + + template + struct rebind { + using other = PoolAllocator; + }; + + explicit PoolAllocator(size_t pool_size = kAllocSize) + : pools_(new MemoryPoolCollection(pool_size)) {} + + PoolAllocator(const PoolAllocator &pool_alloc) + : pools_(pool_alloc.Pools()) { + Pools()->IncrRefCount(); + } + + template + explicit PoolAllocator(const PoolAllocator &pool_alloc) + : pools_(pool_alloc.Pools()) { + Pools()->IncrRefCount(); + } + + ~PoolAllocator() { + if (Pools()->DecrRefCount() == 0) delete Pools(); + } + + pointer address(reference ref) const { return Allocator().address(ref); } + + const_pointer address(const_reference ref) const { + return Allocator().address(ref); + } + + size_type max_size() const { return Allocator().max_size(); } + + template + void construct(U *p, Args &&... args) { + Allocator().construct(p, std::forward(args)...); + } + + void destroy(pointer p) { Allocator().destroy(p); } + + pointer allocate(size_type n, const void *hint = nullptr) { + if (n == 1) { + return static_cast(Pool<1>()->Allocate()); + } else if (n == 2) { + return static_cast(Pool<2>()->Allocate()); + } else if (n <= 4) { + return static_cast(Pool<4>()->Allocate()); + } else if (n <= 8) { + return static_cast(Pool<8>()->Allocate()); + } else if (n <= 16) { + return static_cast(Pool<16>()->Allocate()); + } else if (n <= 32) { + return static_cast(Pool<32>()->Allocate()); + } else if (n <= 64) { + return static_cast(Pool<64>()->Allocate()); + } else { + return Allocator().allocate(n, hint); + } + } + + void deallocate(pointer p, size_type n) { + if (n == 1) { + Pool<1>()->Free(p); + } else if (n == 2) { + Pool<2>()->Free(p); + } else if (n <= 4) { + Pool<4>()->Free(p); + } else if (n <= 8) { + Pool<8>()->Free(p); + } else if (n <= 16) { + Pool<16>()->Free(p); + } else if (n <= 32) { + Pool<32>()->Free(p); + } else if (n <= 64) { + Pool<64>()->Free(p); + } else { + Allocator().deallocate(p, n); + } + } + + MemoryPoolCollection *Pools() const { return pools_; } + + private: + template + struct TN { + T buf[n]; + }; + + template + MemoryPool> *Pool() { + return pools_->Pool>(); + } + + MemoryPoolCollection *pools_; + + PoolAllocator operator=(const PoolAllocator &); +}; + +template +bool operator==(const PoolAllocator &alloc1, + const PoolAllocator &alloc2) { + return false; +} + +template +bool operator!=(const PoolAllocator &alloc1, + const PoolAllocator &alloc2) { + return true; +} + +} // namespace fst + +#endif // FST_MEMORY_H_ diff --git a/projects/llm_framework/include/fst/minimize.h b/projects/llm_framework/include/fst/minimize.h new file mode 100644 index 00000000..9f17a22a --- /dev/null +++ b/projects/llm_framework/include/fst/minimize.h @@ -0,0 +1,568 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to minimize an FST. + +#ifndef FST_MINIMIZE_H_ +#define FST_MINIMIZE_H_ + +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { +namespace internal { + +// Comparator for creating partition. +template +class StateComparator { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + StateComparator(const Fst &fst, const Partition &partition) + : fst_(fst), partition_(partition) {} + + // Compares state x with state y based on sort criteria. + bool operator()(const StateId x, const StateId y) const { + // Checks for final state equivalence. + const auto xfinal = fst_.Final(x).Hash(); + const auto yfinal = fst_.Final(y).Hash(); + if (xfinal < yfinal) { + return true; + } else if (xfinal > yfinal) { + return false; + } + // Checks for number of arcs. + if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true; + if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false; + // If the number of arcs are equal, checks for arc match. + for (ArcIterator> aiter1(fst_, x), aiter2(fst_, y); + !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) { + const auto &arc1 = aiter1.Value(); + const auto &arc2 = aiter2.Value(); + if (arc1.ilabel < arc2.ilabel) return true; + if (arc1.ilabel > arc2.ilabel) return false; + if (partition_.ClassId(arc1.nextstate) < + partition_.ClassId(arc2.nextstate)) + return true; + if (partition_.ClassId(arc1.nextstate) > + partition_.ClassId(arc2.nextstate)) + return false; + } + return false; + } + + private: + const Fst &fst_; + const Partition &partition_; +}; + +// Computes equivalence classes for cyclic unweighted acceptors. For cyclic +// minimization we use the classic Hopcroft minimization algorithm, which has +// complexity O(E log V) where E is the number of arcs and V is the number of +// states. +// +// For more information, see: +// +// Hopcroft, J. 1971. An n Log n algorithm for minimizing states in a finite +// automaton. Ms, Stanford University. +// +// Note: the original presentation of the paper was for a finite automaton (== +// deterministic, unweighted acceptor), but we also apply it to the +// nondeterministic case, where it is also applicable as long as the semiring is +// idempotent (if the semiring is not idempotent, there are some complexities +// in keeping track of the weight when there are multiple arcs to states that +// will be merged, and we don't deal with this). +template +class CyclicMinimizer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using ClassId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using RevArc = ReverseArc; + + explicit CyclicMinimizer(const ExpandedFst &fst) { + Initialize(fst); + Compute(fst); + } + + const Partition &GetPartition() const { return P_; } + + private: + // StateILabelHasher is a hashing object that computes a hash-function + // of an FST state that depends only on the set of ilabels on arcs leaving + // the state [note: it assumes that the arcs are ilabel-sorted]. + // In order to work correctly for non-deterministic automata, multiple + // instances of the same ilabel count the same as a single instance. + class StateILabelHasher { + public: + explicit StateILabelHasher(const Fst &fst) : fst_(fst) {} + + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + + size_t operator()(const StateId s) { + const size_t p1 = 7603; + const size_t p2 = 433024223; + size_t result = p2; + size_t current_ilabel = kNoLabel; + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + Label this_ilabel = aiter.Value().ilabel; + if (this_ilabel != current_ilabel) { // Ignores repeats. + result = p1 * result + this_ilabel; + current_ilabel = this_ilabel; + } + } + return result; + } + + private: + const Fst &fst_; + }; + + class ArcIterCompare { + public: + explicit ArcIterCompare(const Partition &partition) + : partition_(partition) {} + + ArcIterCompare(const ArcIterCompare &comp) : partition_(comp.partition_) {} + + // Compares two iterators based on their input labels. + bool operator()(const ArcIterator> *x, + const ArcIterator> *y) const { + const auto &xarc = x->Value(); + const auto &yarc = y->Value(); + return xarc.ilabel > yarc.ilabel; + } + + private: + const Partition &partition_; + }; + + using ArcIterQueue = + std::priority_queue> *, + std::vector> *>, + ArcIterCompare>; + + private: + // Prepartitions the space into equivalence classes. We ensure that final and + // non-final states always go into different equivalence classes, and we use + // class StateILabelHasher to make sure that most of the time, states with + // different sets of ilabels on arcs leaving them, go to different partitions. + // Note: for the O(n) guarantees we don't rely on the goodness of this + // hashing function---it just provides a bonus speedup. + void PrePartition(const ExpandedFst &fst) { + VLOG(5) << "PrePartition"; + StateId next_class = 0; + auto num_states = fst.NumStates(); + // Allocates a temporary vector to store the initial class mappings, so that + // we can allocate the classes all at once. + std::vector state_to_initial_class(num_states); + { + // We maintain two maps from hash-value to class---one for final states + // (final-prob == One()) and one for non-final states + // (final-prob == Zero()). We are processing unweighted acceptors, so the + // are the only two possible values. + using HashToClassMap = std::unordered_map; + HashToClassMap hash_to_class_nonfinal; + HashToClassMap hash_to_class_final; + StateILabelHasher hasher(fst); + for (StateId s = 0; s < num_states; ++s) { + size_t hash = hasher(s); + HashToClassMap &this_map = + (fst.Final(s) != Weight::Zero() ? hash_to_class_final + : hash_to_class_nonfinal); + // Avoids two map lookups by using 'insert' instead of 'find'. + auto p = this_map.insert(std::make_pair(hash, next_class)); + state_to_initial_class[s] = p.second ? next_class++ : p.first->second; + } + // Lets the unordered_maps go out of scope before we allocate the classes, + // to reduce the maximum amount of memory used. + } + P_.AllocateClasses(next_class); + for (StateId s = 0; s < num_states; ++s) { + P_.Add(s, state_to_initial_class[s]); + } + for (StateId c = 0; c < next_class; ++c) L_.Enqueue(c); + VLOG(5) << "Initial Partition: " << P_.NumClasses(); + } + + // Creates inverse transition Tr_ = rev(fst), loops over states in FST and + // splits on final, creating two blocks in the partition corresponding to + // final, non-final. + void Initialize(const ExpandedFst &fst) { + // Constructs Tr. + Reverse(fst, &Tr_); + ILabelCompare ilabel_comp; + ArcSort(&Tr_, ilabel_comp); + // Tells the partition how many elements to allocate. The first state in + // Tr_ is super-final state. + P_.Initialize(Tr_.NumStates() - 1); + // Prepares initial partition. + PrePartition(fst); + // Allocates arc iterator queue. + ArcIterCompare comp(P_); + aiter_queue_.reset(new ArcIterQueue(comp)); + } + // Partitions all classes with destination C. + void Split(ClassId C) { + // Prepares priority queue: opens arc iterator for each state in C, and + // inserts into priority queue. + for (PartitionIterator siter(P_, C); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (Tr_.NumArcs(s + 1)) { + aiter_queue_->push(new ArcIterator>(Tr_, s + 1)); + } + } + // Now pops arc iterator from queue, splits entering equivalence class, and + // re-inserts updated iterator into queue. + Label prev_label = -1; + while (!aiter_queue_->empty()) { + std::unique_ptr>> aiter(aiter_queue_->top()); + aiter_queue_->pop(); + if (aiter->Done()) continue; + const auto &arc = aiter->Value(); + auto from_state = aiter->Value().nextstate - 1; + auto from_label = arc.ilabel; + if (prev_label != from_label) P_.FinalizeSplit(&L_); + auto from_class = P_.ClassId(from_state); + if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state); + prev_label = from_label; + aiter->Next(); + if (!aiter->Done()) aiter_queue_->push(aiter.release()); + } + P_.FinalizeSplit(&L_); + } + + // Main loop for Hopcroft minimization. + void Compute(const Fst &fst) { + // Processes active classes (FIFO, or FILO). + while (!L_.Empty()) { + const auto C = L_.Head(); + L_.Dequeue(); + Split(C); // Splits on C, all labels in C. + } + } + + private: + // Partioning of states into equivalence classes. + Partition P_; + // Set of active classes to be processed in partition P. + Queue L_; + // Reverses transition function. + VectorFst Tr_; + // Priority queue of open arc iterators for all states in the splitter + // equivalence class. + std::unique_ptr aiter_queue_; +}; + +// Computes equivalence classes for acyclic FST. +// +// Complexity: +// +// O(E) +// +// where E is the number of arcs. +// +// For more information, see: +// +// Revuz, D. 1992. Minimization of acyclic deterministic automata in linear +// time. Theoretical Computer Science 92(1): 181-189. +template +class AcyclicMinimizer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using ClassId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit AcyclicMinimizer(const ExpandedFst &fst) { + Initialize(fst); + Refine(fst); + } + + const Partition &GetPartition() { return partition_; } + + private: + // DFS visitor to compute the height (distance) to final state. + class HeightVisitor { + public: + HeightVisitor() : max_height_(0), num_states_(0) {} + + // Invoked before DFS visit. + void InitVisit(const Fst &fst) {} + + // Invoked when state is discovered (2nd arg is DFS tree root). + bool InitState(StateId s, StateId root) { + // Extends height array and initialize height (distance) to 0. + for (StateId i = height_.size(); i <= s; ++i) height_.push_back(-1); + if (s >= num_states_) num_states_ = s + 1; + return true; + } + + // Invoked when tree arc examined (to undiscovered state). + bool TreeArc(StateId s, const Arc &arc) { return true; } + + // Invoked when back arc examined (to unfinished state). + bool BackArc(StateId s, const Arc &arc) { return true; } + + // Invoked when forward or cross arc examined (to finished state). + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + if (height_[arc.nextstate] + 1 > height_[s]) { + height_[s] = height_[arc.nextstate] + 1; + } + return true; + } + + // Invoked when state finished (parent is kNoStateId for tree root). + void FinishState(StateId s, StateId parent, const Arc *parent_arc) { + if (height_[s] == -1) height_[s] = 0; + const auto h = height_[s] + 1; + if (parent >= 0) { + if (h > height_[parent]) height_[parent] = h; + if (h > max_height_) max_height_ = h; + } + } + + // Invoked after DFS visit. + void FinishVisit() {} + + size_t max_height() const { return max_height_; } + + const std::vector &height() const { return height_; } + + size_t num_states() const { return num_states_; } + + private: + std::vector height_; + size_t max_height_; + size_t num_states_; + }; + + private: + // Cluster states according to height (distance to final state) + void Initialize(const Fst &fst) { + // Computes height (distance to final state). + HeightVisitor hvisitor; + DfsVisit(fst, &hvisitor); + // Creates initial partition based on height. + partition_.Initialize(hvisitor.num_states()); + partition_.AllocateClasses(hvisitor.max_height() + 1); + const auto &hstates = hvisitor.height(); + for (StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]); + } + + // Refines states based on arc sort (out degree, arc equivalence). + void Refine(const Fst &fst) { + using EquivalenceMap = std::map>; + StateComparator comp(fst, partition_); + // Starts with tail (height = 0). + auto height = partition_.NumClasses(); + for (StateId h = 0; h < height; ++h) { + EquivalenceMap equiv_classes(comp); + // Sorts states within equivalence class. + PartitionIterator siter(partition_, h); + equiv_classes[siter.Value()] = h; + for (siter.Next(); !siter.Done(); siter.Next()) { + auto insert_result = + equiv_classes.insert(std::make_pair(siter.Value(), kNoStateId)); + if (insert_result.second) { + insert_result.first->second = partition_.AddClass(); + } + } + // Creates refined partition. + for (siter.Reset(); !siter.Done();) { + const auto s = siter.Value(); + const auto old_class = partition_.ClassId(s); + const auto new_class = equiv_classes[s]; + // A move operation can invalidate the iterator, so we first update + // the iterator to the next element before we move the current element + // out of the list. + siter.Next(); + if (old_class != new_class) partition_.Move(s, new_class); + } + } + } + + private: + Partition partition_; +}; + +// Given a partition and a Mutable FST, merges states of Fst in place (i.e., +// destructively). Merging works by taking the first state in a class of the +// partition to be the representative state for the class. Each arc is then +// reconnected to this state. All states in the class are merged by adding +// their arcs to the representative state. +template +void MergeStates(const Partition &partition, + MutableFst *fst) { + using StateId = typename Arc::StateId; + std::vector state_map(partition.NumClasses()); + for (StateId i = 0; i < partition.NumClasses(); ++i) { + PartitionIterator siter(partition, i); + state_map[i] = siter.Value(); // First state in partition. + } + // Relabels destination states. + for (StateId c = 0; c < partition.NumClasses(); ++c) { + for (PartitionIterator siter(partition, c); !siter.Done(); + siter.Next()) { + const auto s = siter.Value(); + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate = state_map[partition.ClassId(arc.nextstate)]; + if (s == state_map[c]) { // For the first state, just sets destination. + aiter.SetValue(arc); + } else { + fst->AddArc(state_map[c], std::move(arc)); + } + } + } + } + fst->SetStart(state_map[partition.ClassId(fst->Start())]); + Connect(fst); +} + +template +void AcceptorMinimize(MutableFst *fst, + bool allow_acyclic_minimization = true) { + if (!(fst->Properties(kAcceptor | kUnweighted, true) == + (kAcceptor | kUnweighted))) { + FSTERROR() << "FST is not an unweighted acceptor"; + fst->SetProperties(kError, kError); + return; + } + // Connects FST before minimization, handles disconnected states. + Connect(fst); + if (fst->NumStates() == 0) return; + if (allow_acyclic_minimization && fst->Properties(kAcyclic, true)) { + // Acyclic minimization (Revuz). + VLOG(2) << "Acyclic minimization"; + ArcSort(fst, ILabelCompare()); + AcyclicMinimizer minimizer(*fst); + MergeStates(minimizer.GetPartition(), fst); + } else { + // Either the FST has cycles, or it's generated from non-deterministic input + // (which the Revuz algorithm can't handle), so use the cyclic minimization + // algorithm of Hopcroft. + VLOG(2) << "Cyclic minimization"; + CyclicMinimizer> minimizer(*fst); + MergeStates(minimizer.GetPartition(), fst); + } + // Merges in appropriate semiring + ArcUniqueMapper mapper(*fst); + StateMap(fst, mapper); +} + +} // namespace internal + +// In place minimization of deterministic weighted automata and transducers, +// and also non-deterministic ones if they use an idempotent semiring. +// For transducers, if the 'sfst' argument is not null, the algorithm +// produces a compact factorization of the minimal transducer. +// +// In the acyclic deterministic case, we use an algorithm from Revuz that is +// linear in the number of arcs (edges) in the machine. +// +// In the cyclic or non-deterministic case, we use the classical Hopcroft +// minimization (which was presented for the deterministic case but which +// also works for non-deterministic FSTs); this has complexity O(e log v). +// +template +void Minimize(MutableFst *fst, MutableFst *sfst = nullptr, + float delta = kShortestDelta, bool allow_nondet = false) { + using Weight = typename Arc::Weight; + const auto props = fst->Properties( + kAcceptor | kIDeterministic | kWeighted | kUnweighted, true); + bool allow_acyclic_minimization; + if (props & kIDeterministic) { + allow_acyclic_minimization = true; + } else { + // Our approach to minimization of non-deterministic FSTs will only work in + // idempotent semirings---for non-deterministic inputs, a state could have + // multiple transitions to states that will get merged, and we'd have to + // sum their weights. The algorithm doesn't handle that. + if (!(Weight::Properties() & kIdempotent)) { + fst->SetProperties(kError, kError); + FSTERROR() << "Cannot minimize a non-deterministic FST over a " + "non-idempotent semiring"; + return; + } else if (!allow_nondet) { + fst->SetProperties(kError, kError); + FSTERROR() << "Refusing to minimize a non-deterministic FST with " + << "allow_nondet = false"; + return; + } + // The Revuz algorithm won't work for nondeterministic inputs, so if the + // input is nondeterministic, we'll have to pass a bool saying not to use + // that algorithm. We check at this level rather than in AcceptorMinimize(), + // because it's possible that the FST at this level could be deterministic, + // but a harmless type of non-determinism could be introduced by Encode() + // (thanks to kEncodeWeights, if the FST has epsilons and has a final + // weight with weights equal to some epsilon arc.) + allow_acyclic_minimization = false; + } + if (!(props & kAcceptor)) { // Weighted transducer. + VectorFst> gfst; + ArcMap(*fst, &gfst, ToGallicMapper()); + fst->DeleteStates(); + gfst.SetProperties(kAcceptor, kAcceptor); + Push(&gfst, REWEIGHT_TO_INITIAL, delta); + ArcMap(&gfst, QuantizeMapper>(delta)); + EncodeMapper> encoder( + kEncodeLabels | kEncodeWeights, ENCODE); + Encode(&gfst, &encoder); + internal::AcceptorMinimize(&gfst, allow_acyclic_minimization); + Decode(&gfst, encoder); + if (!sfst) { + FactorWeightFst, + GallicFactor> + fwfst(gfst); + std::unique_ptr osyms( + fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr); + ArcMap(fwfst, fst, FromGallicMapper()); + fst->SetOutputSymbols(osyms.get()); + } else { + sfst->SetOutputSymbols(fst->OutputSymbols()); + GallicToNewSymbolsMapper mapper(sfst); + ArcMap(gfst, fst, &mapper); + fst->SetOutputSymbols(sfst->InputSymbols()); + } + } else if (props & kWeighted) { // Weighted acceptor. + Push(fst, REWEIGHT_TO_INITIAL, delta); + ArcMap(fst, QuantizeMapper(delta)); + EncodeMapper encoder(kEncodeLabels | kEncodeWeights, ENCODE); + Encode(fst, &encoder); + internal::AcceptorMinimize(fst, allow_acyclic_minimization); + Decode(fst, encoder); + } else { // Unweighted acceptor. + internal::AcceptorMinimize(fst, allow_acyclic_minimization); + } +} + +} // namespace fst + +#endif // FST_MINIMIZE_H_ diff --git a/projects/llm_framework/include/fst/mutable-fst.h b/projects/llm_framework/include/fst/mutable-fst.h new file mode 100644 index 00000000..9031770d --- /dev/null +++ b/projects/llm_framework/include/fst/mutable-fst.h @@ -0,0 +1,398 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expanded FST augmented with mutators; interface class definition and +// mutable arc iterator interface. + +#ifndef FST_MUTABLE_FST_H_ +#define FST_MUTABLE_FST_H_ + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +template +struct MutableArcIteratorData; + +// Abstract interface for an expanded FST which also supports mutation +// operations. To modify arcs, use MutableArcIterator. +template +class MutableFst : public ExpandedFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + virtual MutableFst &operator=(const Fst &fst) = 0; + + MutableFst &operator=(const MutableFst &fst) { + return operator=(static_cast &>(fst)); + } + + // Sets the initial state. + virtual void SetStart(StateId) = 0; + + // Sets a state's final weight. + virtual void SetFinal(StateId, Weight) = 0; + + // Sets property bits w.r.t. mask. + virtual void SetProperties(uint64 props, uint64 mask) = 0; + + // Adds a state and returns its ID. + virtual StateId AddState() = 0; + + // Adds an arc to state. + virtual void AddArc(StateId, const Arc &arc) = 0; + + // Adds an arc (passed by rvalue reference) to state. Allows subclasses + // to optionally implement move semantics. Defaults to lvalue overload. + virtual void AddArc(StateId state, Arc &&arc) { AddArc(state, arc); } + + // Deletes some states, preserving original StateId ordering. + virtual void DeleteStates(const std::vector &) = 0; + + // Delete all states. + virtual void DeleteStates() = 0; + + // Delete some arcs at a given state. + virtual void DeleteArcs(StateId, size_t n) = 0; + + // Delete all arcs at a given state. + virtual void DeleteArcs(StateId) = 0; + + // Optional, best effort only. + virtual void ReserveStates(StateId n) {} + + // Optional, best effort only. + virtual void ReserveArcs(StateId s, size_t n) {} + + // Returns input label symbol table or nullptr if not specified. + const SymbolTable *InputSymbols() const override = 0; + + // Returns output label symbol table or nullptr if not specified. + const SymbolTable *OutputSymbols() const override = 0; + + // Returns input label symbol table or nullptr if not specified. + virtual SymbolTable *MutableInputSymbols() = 0; + + // Returns output label symbol table or nullptr if not specified. + virtual SymbolTable *MutableOutputSymbols() = 0; + + // Sets input label symbol table; pass nullptr to delete table. + virtual void SetInputSymbols(const SymbolTable *isyms) = 0; + + // Sets output label symbol table; pass nullptr to delete table. + virtual void SetOutputSymbols(const SymbolTable *osyms) = 0; + + // Gets a copy of this MutableFst. See Fst<>::Copy() for further doc. + MutableFst *Copy(bool safe = false) const override = 0; + + // Reads a MutableFst from an input stream, returning nullptr on error. + static MutableFst *Read(std::istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) { + hdr = *opts.header; + } else { + if (!hdr.Read(strm, opts.source)) return nullptr; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kMutable)) { + LOG(ERROR) << "MutableFst::Read: Not a MutableFst: " << ropts.source; + return nullptr; + } + const auto &fst_type = hdr.FstType(); + const auto reader = FstRegister::GetRegister()->GetReader(fst_type); + if (!reader) { + LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << fst_type + << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source; + return nullptr; + } + auto *fst = reader(strm, ropts); + if (!fst) return nullptr; + return static_cast *>(fst); + } + + // Reads a MutableFst from a file; returns nullptr on error. An empty + // filename results in reading from standard input. If convert is true, + // convert to a mutable FST subclass (given by convert_type) in the case + // that the input FST is non-mutable. + static MutableFst *Read(const string &filename, bool convert = false, + const string &convert_type = "vector") { + if (convert == false) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } else { // Converts to 'convert_type' if not mutable. + std::unique_ptr> ifst(Fst::Read(filename)); + if (!ifst) return nullptr; + if (ifst->Properties(kMutable, false)) { + return static_cast *>(ifst.release()); + } else { + std::unique_ptr> ofst(Convert(*ifst, convert_type)); + ifst.reset(); + if (!ofst) return nullptr; + if (!ofst->Properties(kMutable, false)) { + LOG(ERROR) << "MutableFst: Bad convert type: " << convert_type; + } + return static_cast *>(ofst.release()); + } + } + } + + // For generic mutuble arc iterator construction; not normally called + // directly by users. + virtual void InitMutableArcIterator(StateId s, + MutableArcIteratorData *data) = 0; +}; + +// Mutable arc iterator interface, templated on the Arc definition. This is +// used by mutable arc iterator specializations that are returned by the +// InitMutableArcIterator MutableFst method. +template +class MutableArcIteratorBase : public ArcIteratorBase { + public: + // Sets current arc. + virtual void SetValue(const Arc &) = 0; +}; + +template +struct MutableArcIteratorData { + MutableArcIteratorBase *base; // Specific iterator. +}; + +// Generic mutable arc iterator, templated on the FST definition; a wrapper +// around a pointer to a more specific one. +// +// Here is a typical use: +// +// for (MutableArcIterator aiter(&fst, s); +// !aiter.Done(); +// aiter.Next()) { +// StdArc arc = aiter.Value(); +// arc.ilabel = 7; +// aiter.SetValue(arc); +// ... +// } +// +// This version requires function calls. +template +class MutableArcIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + MutableArcIterator(FST *fst, StateId s) { + fst->InitMutableArcIterator(s, &data_); + } + + ~MutableArcIterator() { delete data_.base; } + + bool Done() const { return data_.base->Done(); } + + const Arc &Value() const { return data_.base->Value(); } + + void Next() { data_.base->Next(); } + + size_t Position() const { return data_.base->Position(); } + + void Reset() { data_.base->Reset(); } + + void Seek(size_t a) { data_.base->Seek(a); } + + void SetValue(const Arc &arc) { data_.base->SetValue(arc); } + + uint32 Flags() const { return data_.base->Flags(); } + + void SetFlags(uint32 flags, uint32 mask) { + return data_.base->SetFlags(flags, mask); + } + + private: + MutableArcIteratorData data_; + + MutableArcIterator(const MutableArcIterator &) = delete; + MutableArcIterator &operator=(const MutableArcIterator &) = delete; +}; + +namespace internal { + +// MutableFst case: abstract methods. +template +inline typename Arc::Weight Final(const MutableFst &fst, + typename Arc::StateId s) { + return fst.Final(s); +} + +template +inline ssize_t NumArcs(const MutableFst &fst, typename Arc::StateId s) { + return fst.NumArcs(s); +} + +template +inline ssize_t NumInputEpsilons(const MutableFst &fst, + typename Arc::StateId s) { + return fst.NumInputEpsilons(s); +} + +template +inline ssize_t NumOutputEpsilons(const MutableFst &fst, + typename Arc::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + +// A useful alias when using StdArc. +using StdMutableFst = MutableFst; + +// This is a helper class template useful for attaching a MutableFst interface +// to its implementation, handling reference counting and COW semantics. +template > +class ImplToMutableFst : public ImplToExpandedFst { + public: + using Arc = typename Impl::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ImplToExpandedFst::operator=; + + void SetStart(StateId s) override { + MutateCheck(); + GetMutableImpl()->SetStart(s); + } + + void SetFinal(StateId s, Weight weight) override { + MutateCheck(); + GetMutableImpl()->SetFinal(s, std::move(weight)); + } + + void SetProperties(uint64 props, uint64 mask) override { + // Can skip mutate check if extrinsic properties don't change, + // since it is then safe to update all (shallow) copies + const auto exprops = kExtrinsicProperties & mask; + if (GetImpl()->Properties(exprops) != (props & exprops)) MutateCheck(); + GetMutableImpl()->SetProperties(props, mask); + } + + StateId AddState() override { + MutateCheck(); + return GetMutableImpl()->AddState(); + } + + void AddArc(StateId s, const Arc &arc) override { + MutateCheck(); + GetMutableImpl()->AddArc(s, arc); + } + + void AddArc(StateId s, Arc &&arc) override { + MutateCheck(); + GetMutableImpl()->AddArc(s, std::move(arc)); + } + + void DeleteStates(const std::vector &dstates) override { + MutateCheck(); + GetMutableImpl()->DeleteStates(dstates); + } + + void DeleteStates() override { + if (!Unique()) { + const auto *isymbols = GetImpl()->InputSymbols(); + const auto *osymbols = GetImpl()->OutputSymbols(); + SetImpl(std::make_shared()); + GetMutableImpl()->SetInputSymbols(isymbols); + GetMutableImpl()->SetOutputSymbols(osymbols); + } else { + GetMutableImpl()->DeleteStates(); + } + } + + void DeleteArcs(StateId s, size_t n) override { + MutateCheck(); + GetMutableImpl()->DeleteArcs(s, n); + } + + void DeleteArcs(StateId s) override { + MutateCheck(); + GetMutableImpl()->DeleteArcs(s); + } + + void ReserveStates(StateId s) override { + MutateCheck(); + GetMutableImpl()->ReserveStates(s); + } + + void ReserveArcs(StateId s, size_t n) override { + MutateCheck(); + GetMutableImpl()->ReserveArcs(s, n); + } + + const SymbolTable *InputSymbols() const override { + return GetImpl()->InputSymbols(); + } + + const SymbolTable *OutputSymbols() const override { + return GetImpl()->OutputSymbols(); + } + + SymbolTable *MutableInputSymbols() override { + MutateCheck(); + return GetMutableImpl()->InputSymbols(); + } + + SymbolTable *MutableOutputSymbols() override { + MutateCheck(); + return GetMutableImpl()->OutputSymbols(); + } + + void SetInputSymbols(const SymbolTable *isyms) override { + MutateCheck(); + GetMutableImpl()->SetInputSymbols(isyms); + } + + void SetOutputSymbols(const SymbolTable *osyms) override { + MutateCheck(); + GetMutableImpl()->SetOutputSymbols(osyms); + } + + protected: + using ImplToExpandedFst::GetImpl; + using ImplToExpandedFst::GetMutableImpl; + using ImplToExpandedFst::Unique; + using ImplToExpandedFst::SetImpl; + using ImplToExpandedFst::InputSymbols; + + explicit ImplToMutableFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + ImplToMutableFst(const ImplToMutableFst &fst, bool safe) + : ImplToExpandedFst(fst, safe) {} + + void MutateCheck() { + if (!Unique()) SetImpl(std::make_shared(*this)); + } +}; + +} // namespace fst + +#endif // FST_MUTABLE_FST_H_ diff --git a/projects/llm_framework/include/fst/pair-weight.h b/projects/llm_framework/include/fst/pair-weight.h new file mode 100644 index 00000000..1f2e963c --- /dev/null +++ b/projects/llm_framework/include/fst/pair-weight.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Pair weight templated base class for weight classes that contain two weights +// (e.g. Product, Lexicographic). + +#ifndef FST_PAIR_WEIGHT_H_ +#define FST_PAIR_WEIGHT_H_ + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +template +class PairWeight { + public: + using ReverseWeight = + PairWeight; + + PairWeight() {} + + PairWeight(W1 w1, W2 w2) : value1_(std::move(w1)), value2_(std::move(w2)) {} + + static const PairWeight &Zero() { + static const PairWeight zero(W1::Zero(), W2::Zero()); + return zero; + } + + static const PairWeight &One() { + static const PairWeight one(W1::One(), W2::One()); + return one; + } + + static const PairWeight &NoWeight() { + static const PairWeight no_weight(W1::NoWeight(), W2::NoWeight()); + return no_weight; + } + + std::istream &Read(std::istream &strm) { + value1_.Read(strm); + return value2_.Read(strm); + } + + std::ostream &Write(std::ostream &strm) const { + value1_.Write(strm); + return value2_.Write(strm); + } + + bool Member() const { return value1_.Member() && value2_.Member(); } + + size_t Hash() const { + const auto h1 = value1_.Hash(); + const auto h2 = value2_.Hash(); + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } + + PairWeight Quantize(float delta = kDelta) const { + return PairWeight(value1_.Quantize(delta), value2_.Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(value1_.Reverse(), value2_.Reverse()); + } + + const W1 &Value1() const { return value1_; } + + const W2 &Value2() const { return value2_; } + + void SetValue1(const W1 &weight) { value1_ = weight; } + + void SetValue2(const W2 &weight) { value2_ = weight; } + + private: + W1 value1_; + W2 value2_; +}; + +template +inline bool operator==(const PairWeight &w1, + const PairWeight &w2) { + return w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2(); +} + +template +inline bool operator!=(const PairWeight &w1, + const PairWeight &w2) { + return w1.Value1() != w2.Value1() || w1.Value2() != w2.Value2(); +} + +template +inline bool ApproxEqual(const PairWeight &w1, + const PairWeight &w2, float delta = kDelta) { + return ApproxEqual(w1.Value1(), w2.Value1(), delta) && + ApproxEqual(w1.Value2(), w2.Value2(), delta); +} + +template +inline std::ostream &operator<<(std::ostream &strm, + const PairWeight &weight) { + CompositeWeightWriter writer(strm); + writer.WriteBegin(); + writer.WriteElement(weight.Value1()); + writer.WriteElement(weight.Value2()); + writer.WriteEnd(); + return strm; +} + +template +inline std::istream &operator>>(std::istream &strm, + PairWeight &weight) { + CompositeWeightReader reader(strm); + reader.ReadBegin(); + W1 w1; + reader.ReadElement(&w1); + weight.SetValue1(w1); + W2 w2; + reader.ReadElement(&w2, true); + weight.SetValue2(w2); + reader.ReadEnd(); + return strm; +} + +// This function object returns weights by calling the underlying generators +// and forming a pair. This is intended primarily for testing. +template +class WeightGenerate> { + public: + using Weight = PairWeight; + using Generate1 = WeightGenerate; + using Generate2 = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true) + : generate1_(allow_zero), generate2_(allow_zero) {} + + Weight operator()() const { return Weight(generate1_(), generate2_()); } + + private: + Generate1 generate1_; + Generate2 generate2_; +}; + +} // namespace fst + +#endif // FST_PAIR_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/partition.h b/projects/llm_framework/include/fst/partition.h new file mode 100644 index 00000000..5dbbe46a --- /dev/null +++ b/projects/llm_framework/include/fst/partition.h @@ -0,0 +1,305 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to create a partition of states. + +#ifndef FST_PARTITION_H_ +#define FST_PARTITION_H_ + +#include +#include + + +#include + + +namespace fst { +namespace internal { + +template +class PartitionIterator; + +// Defines a partitioning of elements, used to represent equivalence classes +// for FST operations like minimization. T must be a signed integer type. +// +// The elements are numbered from 0 to num_elements - 1. +// Initialize(num_elements) sets up the class for a given number of elements. +// We maintain a partition of these elements into classes. The classes are also +// numbered from zero; you can add a class with AddClass(), or add them in bulk +// with AllocateClasses(num_classes). Initially the elements are not assigned +// to any class; you set up the initial mapping from elements to classes by +// calling Add(element_id, class_id). You can also move an element to a +// different class by calling Move(element_id, class_id). +// +// We also support a rather specialized interface that allows you to efficiently +// split classes in the Hopcroft minimization algorithm. This maintains a +// binary partition of each class. Let's call these, rather arbitrarily, the +// 'yes' subset and the 'no' subset of each class, and assume that by default, +// each element of a class is in its 'no' subset. When one calls +// SplitOn(element_id), element_id is moved to the 'yes' subset of its class. +// (If it was already in the 'yes' set, it just stays there). The aim is to +// enable (later) splitting the class in two in time no greater than the time +// already spent calling SplitOn() for that class. We keep a list of the classes +// which have nonempty 'yes' sets, as visited_classes_. When one calls +// FinalizeSplit(Queue *l), for each class in visited_classes_ whose 'yes' +// and 'no' sets are both nonempty, it will create a new class consisting of +// the smaller of the two subsets (and this class will be added to the queue), +// and the old class will now be the larger of the two subsets. This call also +// resets all the yes/no partitions so that everything is in the 'no' subsets. +// +// One cannot use the Move() function if SplitOn() has been called without +// a subsequent call to FinalizeSplit() +template +class Partition { + public: + Partition() {} + + explicit Partition(T num_elements) { Initialize(num_elements); } + + // Creates an empty partition for num_elements. This means that the elements + // are not assigned to a class (i.e class_index = -1); you should set up the + // number of classes using AllocateClasses() or AddClass(), and allocate each + // element to a class by calling Add(element, class_id). + void Initialize(size_t num_elements) { + elements_.resize(num_elements); + classes_.reserve(num_elements); + classes_.clear(); + yes_counter_ = 1; + } + + // Adds a class; returns new number of classes. + T AddClass() { + auto num_classes = classes_.size(); + classes_.resize(num_classes + 1); + return num_classes; + } + + // Adds 'num_classes' new (empty) classes. + void AllocateClasses(T num_classes) { + classes_.resize(classes_.size() + num_classes); + } + + // Adds element_id to class_id. element_id should already have been allocated + // by calling Initialize(num_elements)---or the constructor taking + // num_elements---with num_elements > element_id. element_id must not + // currently be a member of any class; once elements have been added to a + // class, use the Move() method to move them from one class to another. + void Add(T element_id, T class_id) { + auto &this_element = elements_[element_id]; + auto &this_class = classes_[class_id]; + ++this_class.size; + // Adds the element to the 'no' subset of the class. + auto no_head = this_class.no_head; + if (no_head >= 0) elements_[no_head].prev_element = element_id; + this_class.no_head = element_id; + this_element.class_id = class_id; + // Adds to the 'no' subset of the class. + this_element.yes = 0; + this_element.next_element = no_head; + this_element.prev_element = -1; + } + + // Moves element_id from 'no' subset of its current class to 'no' subset of + // class class_id. This may not work correctly if you have called SplitOn() + // [for any element] and haven't subsequently called FinalizeSplit(). + void Move(T element_id, T class_id) { + auto elements = &(elements_[0]); + auto &element = elements[element_id]; + auto &old_class = classes_[element.class_id]; + --old_class.size; + // Excises the element from the 'no' list of its old class, where it is + // assumed to be. + if (element.prev_element >= 0) { + elements[element.prev_element].next_element = element.next_element; + } else { + old_class.no_head = element.next_element; + } + if (element.next_element >= 0) { + elements[element.next_element].prev_element = element.prev_element; + } + // Adds to new class. + Add(element_id, class_id); + } + + // Moves element_id to the 'yes' subset of its class if it was in the 'no' + // subset, and marks the class as having been visited. + void SplitOn(T element_id) { + auto elements = &(elements_[0]); + auto &element = elements[element_id]; + if (element.yes == yes_counter_) { + return; // Already in the 'yes' set; nothing to do. + } + auto class_id = element.class_id; + auto &this_class = classes_[class_id]; + // Excises the element from the 'no' list of its class. + if (element.prev_element >= 0) { + elements[element.prev_element].next_element = element.next_element; + } else { + this_class.no_head = element.next_element; + } + if (element.next_element >= 0) { + elements[element.next_element].prev_element = element.prev_element; + } + // Adds the element to the 'yes' list. + if (this_class.yes_head >= 0) { + elements[this_class.yes_head].prev_element = element_id; + } else { + visited_classes_.push_back(class_id); + } + element.yes = yes_counter_; + element.next_element = this_class.yes_head; + element.prev_element = -1; + this_class.yes_head = element_id; + this_class.yes_size++; + } + + // This should be called after one has possibly called SplitOn for one or more + // elements, thus moving those elements to the 'yes' subset for their class. + // For each class that has a nontrivial split (i.e., it's not the case that + // all members are in the 'yes' or 'no' subset), this function creates a new + // class containing the smaller of the two subsets of elements, leaving the + // larger group of elements in the old class. The identifier of the new class + // will be added to the queue provided as the pointer L. This method then + // moves all elements to the 'no' subset of their class. + template + void FinalizeSplit(Queue *queue) { + for (const auto &visited_class : visited_classes_) { + const auto new_class = SplitRefine(visited_class); + if (new_class != -1 && queue) queue->Enqueue(new_class); + } + visited_classes_.clear(); + // Incrementation sets all the 'yes' members of the elements to false. + ++yes_counter_; + } + + const T ClassId(T element_id) const { return elements_[element_id].class_id; } + + const size_t ClassSize(T class_id) const { return classes_[class_id].size; } + + const T NumClasses() const { return classes_.size(); } + + private: + friend class PartitionIterator; + + // Information about a given element. + struct Element { + T class_id; // Class ID of this element. + T yes; // This is to be interpreted as a bool, true if it's in the + // 'yes' set of this class. The interpretation as bool is + // (yes == yes_counter_ ? true : false). + T next_element; // Next element in the 'no' list or 'yes' list of this + // class, whichever of the two we belong to (think of + // this as the 'next' in a doubly-linked list, although + // it is an index into the elements array). Negative + // values corresponds to null. + T prev_element; // Previous element in the 'no' or 'yes' doubly linked + // list. Negative values corresponds to null. + }; + + // Information about a given class. + struct Class { + Class() : size(0), yes_size(0), no_head(-1), yes_head(-1) {} + T size; // Total number of elements in this class ('no' plus 'yes' + // subsets). + T yes_size; // Total number of elements of 'yes' subset of this class. + T no_head; // Index of head element of doubly-linked list in 'no' subset. + // Everything is in the 'no' subset until you call SplitOn(). + // -1 means no element. + T yes_head; // Index of head element of doubly-linked list in 'yes' subset. + // -1 means no element. + }; + + // This method, called from FinalizeSplit(), checks whether a class has to + // be split (a class will be split only if its 'yes' and 'no' subsets are + // both nonempty, but one can assume that since this function was called, the + // 'yes' subset is nonempty). It splits by taking the smaller subset and + // making it a new class, and leaving the larger subset of elements in the + // 'no' subset of the old class. It returns the new class if created, or -1 + // if none was created. + T SplitRefine(T class_id) { + auto yes_size = classes_[class_id].yes_size; + auto size = classes_[class_id].size; + auto no_size = size - yes_size; + if (no_size == 0) { + // All members are in the 'yes' subset, so we don't have to create a new + // class, just move them all to the 'no' subset. + classes_[class_id].no_head = classes_[class_id].yes_head; + classes_[class_id].yes_head = -1; + classes_[class_id].yes_size = 0; + return -1; + } else { + auto new_class_id = classes_.size(); + classes_.resize(classes_.size() + 1); + auto &old_class = classes_[class_id]; + auto &new_class = classes_[new_class_id]; + // The new_class will have the values from the constructor. + if (no_size < yes_size) { + // Moves the 'no' subset to new class ('no' subset). + new_class.no_head = old_class.no_head; + new_class.size = no_size; + // And makes the 'yes' subset of the old class ('no' subset). + old_class.no_head = old_class.yes_head; + old_class.yes_head = -1; + old_class.size = yes_size; + old_class.yes_size = 0; + } else { + // Moves the 'yes' subset to the new class (to the 'no' subset) + new_class.size = yes_size; + new_class.no_head = old_class.yes_head; + // Retains only the 'no' subset in the old class. + old_class.size = no_size; + old_class.yes_size = 0; + old_class.yes_head = -1; + } + auto elements = &(elements_[0]); + // Updates the 'class_id' of all the elements we moved. + for (auto e = new_class.no_head; e >= 0; e = elements[e].next_element) { + elements[e].class_id = new_class_id; + } + return new_class_id; + } + } + + // elements_[i] contains all info about the i'th element. + std::vector elements_; + // classes_[i] contains all info about the i'th class. + std::vector classes_; + // Set of visited classes to be used in split refine. + std::vector visited_classes_; + // yes_counter_ is used in interpreting the 'yes' members of class Element. + // If element.yes == yes_counter_, we interpret that element as being in the + // 'yes' subset of its class. This allows us to, in effect, set all those + // bools to false at a stroke by incrementing yes_counter_. + T yes_counter_; +}; + +// Iterates over members of the 'no' subset of a class in a partition. (When +// this is used, everything is in the 'no' subset). +template +class PartitionIterator { + public: + using Element = typename Partition::Element; + + PartitionIterator(const Partition &partition, T class_id) + : partition_(partition), + element_id_(partition_.classes_[class_id].no_head), + class_id_(class_id) {} + + bool Done() { return element_id_ < 0; } + + const T Value() { return element_id_; } + + void Next() { element_id_ = partition_.elements_[element_id_].next_element; } + + void Reset() { element_id_ = partition_.classes_[class_id_].no_head; } + + private: + const Partition &partition_; + T element_id_; + T class_id_; +}; + +} // namespace internal +} // namespace fst + +#endif // FST_PARTITION_H_ diff --git a/projects/llm_framework/include/fst/power-weight.h b/projects/llm_framework/include/fst/power-weight.h new file mode 100644 index 00000000..f2f3cbdb --- /dev/null +++ b/projects/llm_framework/include/fst/power-weight.h @@ -0,0 +1,168 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Cartesian power weight semiring operation definitions. + +#ifndef FST_POWER_WEIGHT_H_ +#define FST_POWER_WEIGHT_H_ + +#include + +#include +#include + + +namespace fst { + +// Cartesian power semiring: W ^ n +// +// Forms: +// - a left semimodule when W is a left semiring, +// - a right semimodule when W is a right semiring, +// - a bisemimodule when W is a semiring, +// the free semimodule of rank n over W +// The Times operation is overloaded to provide the left and right scalar +// products. +template +class PowerWeight : public TupleWeight { + public: + using ReverseWeight = PowerWeight; + + PowerWeight() {} + + explicit PowerWeight(const TupleWeight &weight) + : TupleWeight(weight) {} + + template + PowerWeight(Iterator begin, Iterator end) : TupleWeight(begin, end) {} + + // Initialize component `index` to `weight`; initialize all other components + // to `default_weight` + PowerWeight(size_t index, const W &weight, + const W &default_weight = W::Zero()) + : TupleWeight(index, weight, default_weight) {} + + static const PowerWeight &Zero() { + static const PowerWeight zero(TupleWeight::Zero()); + return zero; + } + + static const PowerWeight &One() { + static const PowerWeight one(TupleWeight::One()); + return one; + } + + static const PowerWeight &NoWeight() { + static const PowerWeight no_weight(TupleWeight::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string(W::Type() + "_^" + std::to_string(n)); + return *type; + } + + static constexpr uint64 Properties() { + return W::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + PowerWeight Quantize(float delta = kDelta) const { + return PowerWeight(TupleWeight::Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(TupleWeight::Reverse()); + } +}; + +// Semiring plus operation. +template +inline PowerWeight Plus(const PowerWeight &w1, + const PowerWeight &w2) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Plus(w1.Value(i), w2.Value(i))); + } + return result; +} + +// Semiring times operation. +template +inline PowerWeight Times(const PowerWeight &w1, + const PowerWeight &w2) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(w1.Value(i), w2.Value(i))); + } + return result; +} + +// Semiring divide operation. +template +inline PowerWeight Divide(const PowerWeight &w1, + const PowerWeight &w2, + DivideType type = DIVIDE_ANY) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Divide(w1.Value(i), w2.Value(i), type)); + } + return result; +} + +// Semimodule left scalar product. +template +inline PowerWeight Times(const W &scalar, + const PowerWeight &weight) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(scalar, weight.Value(i))); + } + return result; +} + +// Semimodule right scalar product. +template +inline PowerWeight Times(const PowerWeight &weight, + const W &scalar) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(weight.Value(i), scalar)); + } + return result; +} + +// Semimodule dot product. +template +inline W DotProduct(const PowerWeight &w1, const PowerWeight &w2) { + W result(W::Zero()); + for (size_t i = 0; i < n; ++i) { + result = Plus(result, Times(w1.Value(i), w2.Value(i))); + } + return result; +} + +// This function object generates weights over the Cartesian power of rank +// n over the underlying weight. This is intended primarily for testing. +template +class WeightGenerate> { + public: + using Weight = PowerWeight; + using Generate = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true) : generate_(allow_zero) {} + + Weight operator()() const { + Weight result; + for (size_t i = 0; i < n; ++i) result.SetValue(i, generate_()); + return result; + } + + private: + Generate generate_; +}; + +} // namespace fst + +#endif // FST_POWER_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/product-weight.h b/projects/llm_framework/include/fst/product-weight.h new file mode 100644 index 00000000..56a18be1 --- /dev/null +++ b/projects/llm_framework/include/fst/product-weight.h @@ -0,0 +1,107 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Product weight set and associated semiring operation definitions. + +#ifndef FST_PRODUCT_WEIGHT_H_ +#define FST_PRODUCT_WEIGHT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Product semiring: W1 * W2. +template +class ProductWeight : public PairWeight { + public: + using ReverseWeight = + ProductWeight; + + ProductWeight() {} + + explicit ProductWeight(const PairWeight &weight) + : PairWeight(weight) {} + + ProductWeight(W1 w1, W2 w2) + : PairWeight(std::move(w1), std::move(w2)) {} + + static const ProductWeight &Zero() { + static const ProductWeight zero(PairWeight::Zero()); + return zero; + } + + static const ProductWeight &One() { + static const ProductWeight one(PairWeight::One()); + return one; + } + + static const ProductWeight &NoWeight() { + static const ProductWeight no_weight(PairWeight::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string(W1::Type() + "_X_" + W2::Type()); + return *type; + } + + static constexpr uint64 Properties() { + return W1::Properties() & W2::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + ProductWeight Quantize(float delta = kDelta) const { + return ProductWeight(PairWeight::Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(PairWeight::Reverse()); + } +}; + +template +inline ProductWeight Plus(const ProductWeight &w1, + const ProductWeight &w2) { + return ProductWeight(Plus(w1.Value1(), w2.Value1()), + Plus(w1.Value2(), w2.Value2())); +} + +template +inline ProductWeight Times(const ProductWeight &w1, + const ProductWeight &w2) { + return ProductWeight(Times(w1.Value1(), w2.Value1()), + Times(w1.Value2(), w2.Value2())); +} + +template +inline ProductWeight Divide(const ProductWeight &w1, + const ProductWeight &w2, + DivideType typ = DIVIDE_ANY) { + return ProductWeight(Divide(w1.Value1(), w2.Value1(), typ), + Divide(w1.Value2(), w2.Value2(), typ)); +} + +// This function object generates weights by calling the underlying generators +// for the template weight types, like all other pair weight types. This is +// intended primarily for testing. +template +class WeightGenerate> : + public WeightGenerate> { + public: + using Weight = ProductWeight; + using Generate = WeightGenerate>; + + explicit WeightGenerate(bool allow_zero = true) : Generate(allow_zero) {} + + Weight operator()() const { return Weight(Generate::operator()()); } +}; + +} // namespace fst + +#endif // FST_PRODUCT_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/project.h b/projects/llm_framework/include/fst/project.h new file mode 100644 index 00000000..5a82cf14 --- /dev/null +++ b/projects/llm_framework/include/fst/project.h @@ -0,0 +1,159 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to project an FST on to its domain or range. + +#ifndef FST_PROJECT_H_ +#define FST_PROJECT_H_ + +#include +#include + + +namespace fst { + +// This specifies whether to project on input or output. +enum ProjectType { PROJECT_INPUT = 1, PROJECT_OUTPUT = 2 }; + +// Mapper to implement projection per arc. +template +class ProjectMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr explicit ProjectMapper(ProjectType project_type) + : project_type_(project_type) {} + + ToArc operator()(const FromArc &arc) const { + const auto label = project_type_ == PROJECT_INPUT ? arc.ilabel : arc.olabel; + return ToArc(label, label, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { + return MAP_NO_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return project_type_ == PROJECT_INPUT ? MAP_COPY_SYMBOLS + : MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return project_type_ == PROJECT_OUTPUT ? MAP_COPY_SYMBOLS + : MAP_CLEAR_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return ProjectProperties(props, project_type_ == PROJECT_INPUT); + } + + private: + const ProjectType project_type_; +}; + +// Projects an FST onto its domain or range by either copying each arcs' input +// label to the output label or vice versa. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(1) +// +// where V is the number of states and E is the number of arcs. +template +inline void Project(const Fst &ifst, MutableFst *ofst, + ProjectType project_type) { + ArcMap(ifst, ofst, ProjectMapper(project_type)); + switch (project_type) { + case PROJECT_INPUT: + ofst->SetOutputSymbols(ifst.InputSymbols()); + return; + case PROJECT_OUTPUT: + ofst->SetInputSymbols(ifst.OutputSymbols()); + return; + } +} + +// Destructive variant of the above. +template +inline void Project(MutableFst *fst, ProjectType project_type) { + ArcMap(fst, ProjectMapper(project_type)); + switch (project_type) { + case PROJECT_INPUT: + fst->SetOutputSymbols(fst->InputSymbols()); + return; + case PROJECT_OUTPUT: + fst->SetInputSymbols(fst->OutputSymbols()); + return; + } +} + +// Projects an FST onto its domain or range by either copying each arc's input +// label to the output label or vice versa. This version is a delayed FST. +// +// Complexity: +// +// Time: O(v + e) +// Space: O(1) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time and to visit an input state or arc is assumed and exclusive of +// caching. +template +class ProjectFst : public ArcMapFst> { + public: + using FromArc = A; + using ToArc = A; + + using Impl = internal::ArcMapFstImpl>; + + ProjectFst(const Fst &fst, ProjectType project_type) + : ArcMapFst>(fst, ProjectMapper(project_type)) { + if (project_type == PROJECT_INPUT) { + GetMutableImpl()->SetOutputSymbols(fst.InputSymbols()); + } + if (project_type == PROJECT_OUTPUT) { + GetMutableImpl()->SetInputSymbols(fst.OutputSymbols()); + } + } + + // See Fst<>::Copy() for doc. + ProjectFst(const ProjectFst &fst, bool safe = false) + : ArcMapFst>(fst, safe) {} + + // Gets a copy of this ProjectFst. See Fst<>::Copy() for further doc. + ProjectFst *Copy(bool safe = false) const override { + return new ProjectFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for ProjectFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const ProjectFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for ProjectFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + using StateId = typename A::StateId; + + ArcIterator(const ProjectFst &fst, StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdProjectFst = ProjectFst; + +} // namespace fst + +#endif // FST_PROJECT_H_ diff --git a/projects/llm_framework/include/fst/properties.h b/projects/llm_framework/include/fst/properties.h new file mode 100644 index 00000000..157247a6 --- /dev/null +++ b/projects/llm_framework/include/fst/properties.h @@ -0,0 +1,468 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST property bits. + +#ifndef FST_PROPERTIES_H_ +#define FST_PROPERTIES_H_ + +#include +#include + +#include + +namespace fst { + +// The property bits here assert facts about an FST. If individual bits are +// added, then the composite properties below, the property functions and +// property names in properties.cc, and TestProperties() in test-properties.h +// should be updated. + +// BINARY PROPERTIES +// +// For each property below, there is a single bit. If it is set, the property is +// true. If it is not set, the property is false. + +// The Fst is an ExpandedFst. +constexpr uint64 kExpanded = 0x0000000000000001ULL; + +// The Fst is a MutableFst. +constexpr uint64 kMutable = 0x0000000000000002ULL; + +// An error was detected while constructing/using the FST. +constexpr uint64 kError = 0x0000000000000004ULL; + +// TRINARY PROPERTIES +// +// For each of these properties below there is a pair of property bits, one +// positive and one negative. If the positive bit is set, the property is true. +// If the negative bit is set, the property is false. If neither is set, the +// property has unknown value. Both should never be simultaneously set. The +// individual positive and negative bit pairs should be adjacent with the +// positive bit at an odd and lower position. + +// ilabel == olabel for each arc. +constexpr uint64 kAcceptor = 0x0000000000010000ULL; +// ilabel != olabel for some arc. +constexpr uint64 kNotAcceptor = 0x0000000000020000ULL; + +// ilabels unique leaving each state. +constexpr uint64 kIDeterministic = 0x0000000000040000ULL; +// ilabels not unique leaving some state. +constexpr uint64 kNonIDeterministic = 0x0000000000080000ULL; + +// olabels unique leaving each state. +constexpr uint64 kODeterministic = 0x0000000000100000ULL; +// olabels not unique leaving some state. +constexpr uint64 kNonODeterministic = 0x0000000000200000ULL; + +// FST has input/output epsilons. +constexpr uint64 kEpsilons = 0x0000000000400000ULL; +// FST has no input/output epsilons. +constexpr uint64 kNoEpsilons = 0x0000000000800000ULL; + +// FST has input epsilons. +constexpr uint64 kIEpsilons = 0x0000000001000000ULL; +// FST has no input epsilons. +constexpr uint64 kNoIEpsilons = 0x0000000002000000ULL; + +// FST has output epsilons. +constexpr uint64 kOEpsilons = 0x0000000004000000ULL; +// FST has no output epsilons. +constexpr uint64 kNoOEpsilons = 0x0000000008000000ULL; + +// ilabels sorted wrt < for each state. +constexpr uint64 kILabelSorted = 0x0000000010000000ULL; +// ilabels not sorted wrt < for some state. +constexpr uint64 kNotILabelSorted = 0x0000000020000000ULL; + +// olabels sorted wrt < for each state. +constexpr uint64 kOLabelSorted = 0x0000000040000000ULL; +// olabels not sorted wrt < for some state. +constexpr uint64 kNotOLabelSorted = 0x0000000080000000ULL; + +// Non-trivial arc or final weights. +constexpr uint64 kWeighted = 0x0000000100000000ULL; +// Only trivial arc and final weights. +constexpr uint64 kUnweighted = 0x0000000200000000ULL; + +// FST has cycles. +constexpr uint64 kCyclic = 0x0000000400000000ULL; +// FST has no cycles. +constexpr uint64 kAcyclic = 0x0000000800000000ULL; + +// FST has cycles containing the initial state. +constexpr uint64 kInitialCyclic = 0x0000001000000000ULL; +// FST has no cycles containing the initial state. +constexpr uint64 kInitialAcyclic = 0x0000002000000000ULL; + +// FST is topologically sorted. +constexpr uint64 kTopSorted = 0x0000004000000000ULL; +// FST is not topologically sorted. +constexpr uint64 kNotTopSorted = 0x0000008000000000ULL; + +// All states reachable from the initial state. +constexpr uint64 kAccessible = 0x0000010000000000ULL; +// Not all states reachable from the initial state. +constexpr uint64 kNotAccessible = 0x0000020000000000ULL; + +// All states can reach a final state. +constexpr uint64 kCoAccessible = 0x0000040000000000ULL; +// Not all states can reach a final state. +constexpr uint64 kNotCoAccessible = 0x0000080000000000ULL; + +// If NumStates() > 0, then state 0 is initial, state NumStates() - 1 is final, +// there is a transition from each non-final state i to state i + 1, and there +// are no other transitions. +constexpr uint64 kString = 0x0000100000000000ULL; + +// Not a string FST. +constexpr uint64 kNotString = 0x0000200000000000ULL; + +// FST has least one weighted cycle. +constexpr uint64 kWeightedCycles = 0x0000400000000000ULL; + +// Only unweighted cycles. +constexpr uint64 kUnweightedCycles = 0x0000800000000000ULL; + +// COMPOSITE PROPERTIES + +// Properties of an empty machine. +constexpr uint64 kNullProperties = + kAcceptor | kIDeterministic | kODeterministic | kNoEpsilons | kNoIEpsilons | + kNoOEpsilons | kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | + kInitialAcyclic | kTopSorted | kAccessible | kCoAccessible | kString | + kUnweightedCycles; + +// Properties that are preserved when an FST is copied. +constexpr uint64 kCopyProperties = + kError | kAcceptor | kNotAcceptor | kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | kEpsilons | kNoEpsilons | + kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are intrinsic to the FST. +constexpr uint64 kIntrinsicProperties = + kExpanded | kMutable | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are (potentially) extrinsic to the FST. +constexpr uint64 kExtrinsicProperties = kError; + +// Properties that are preserved when an FST start state is set. +constexpr uint64 kSetStartProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kTopSorted | kNotTopSorted | + kCoAccessible | kNotCoAccessible | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST final weight is set. +constexpr uint64 kSetFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | + kNotTopSorted | kAccessible | kNotAccessible | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST state is added. +constexpr uint64 kAddStateProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kNotAccessible | + kNotCoAccessible | kNotString | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST arc is added. +constexpr uint64 kAddArcProperties = + kExpanded | kMutable | kError | kNotAcceptor | kNonIDeterministic | + kNonODeterministic | kEpsilons | kIEpsilons | kOEpsilons | + kNotILabelSorted | kNotOLabelSorted | kWeighted | kCyclic | kInitialCyclic | + kNotTopSorted | kAccessible | kCoAccessible | kWeightedCycles; + +// Properties that are preserved when an FST arc is set. +constexpr uint64 kSetArcProperties = kExpanded | kMutable | kError; + +// Properties that are preserved when FST states are deleted. +constexpr uint64 kDeleteStatesProperties = + kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | kInitialAcyclic | + kTopSorted | kUnweightedCycles; + +// Properties that are preserved when FST arcs are deleted. +constexpr uint64 kDeleteArcsProperties = + kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | kInitialAcyclic | + kTopSorted | kNotAccessible | kNotCoAccessible | kUnweightedCycles; + +// Properties that are preserved when an FST's states are reordered. +constexpr uint64 kStateSortProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST's arcs are reordered. +constexpr uint64 kArcSortProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST's input labels are changed. +constexpr uint64 kILabelInvariantProperties = + kExpanded | kMutable | kError | kODeterministic | kNonODeterministic | + kOEpsilons | kNoOEpsilons | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST's output labels are changed. +constexpr uint64 kOLabelInvariantProperties = + kExpanded | kMutable | kError | kIDeterministic | kNonIDeterministic | + kIEpsilons | kNoIEpsilons | kILabelSorted | kNotILabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST's weights are changed. This +// assumes that the set of states that are non-final is not changed. +constexpr uint64 kWeightInvariantProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | + kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString; + +// Properties that are preserved when a superfinal state is added and an FST's +// final weights are directed to it via new transitions. +constexpr uint64 kAddSuperFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | + kNonIDeterministic | kNonODeterministic | kEpsilons | kIEpsilons | + kOEpsilons | kNotILabelSorted | kNotOLabelSorted | kWeighted | kUnweighted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kNotTopSorted | + kNotAccessible | kCoAccessible | kNotCoAccessible | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when a superfinal state is removed and the +// epsilon transitions directed to it are made final weights. +constexpr uint64 kRmSuperFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kWeighted | kUnweighted | kCyclic | + kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | kAccessible | + kCoAccessible | kNotCoAccessible | kString | kWeightedCycles | + kUnweightedCycles; + +// All binary properties. +constexpr uint64 kBinaryProperties = 0x0000000000000007ULL; + +// All trinary properties. +constexpr uint64 kTrinaryProperties = 0x0000ffffffff0000ULL; + +// COMPUTED PROPERTIES + +// 1st bit of trinary properties. +constexpr uint64 kPosTrinaryProperties = kTrinaryProperties & + 0x5555555555555555ULL; + +// 2nd bit of trinary properties. +constexpr uint64 kNegTrinaryProperties = kTrinaryProperties & + 0xaaaaaaaaaaaaaaaaULL; + +// All properties. +constexpr uint64 kFstProperties = kBinaryProperties | kTrinaryProperties; + +// PROPERTY FUNCTIONS and STRING NAMES (defined in properties.cc). + +// Below are functions for getting property bit vectors when executing +// mutation operations. +inline uint64 SetStartProperties(uint64 inprops); + +template +uint64 SetFinalProperties(uint64 inprops, const Weight &old_weight, + const Weight &new_weight); + +inline uint64 AddStateProperties(uint64 inprops); + +template +uint64 AddArcProperties(uint64 inprops, typename A::StateId s, const A &arc, + const A *prev_arc); + +inline uint64 DeleteStatesProperties(uint64 inprops); + +inline uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticProps); + +inline uint64 DeleteArcsProperties(uint64 inprops); + +uint64 ClosureProperties(uint64 inprops, bool star, bool delayed = false); + +uint64 ComplementProperties(uint64 inprops); + +uint64 ComposeProperties(uint64 inprops1, uint64 inprops2); + +uint64 ConcatProperties(uint64 inprops1, uint64 inprops2, bool delayed = false); + +uint64 DeterminizeProperties(uint64 inprops, bool has_subsequential_label, + bool distinct_psubsequential_labels); + +uint64 FactorWeightProperties(uint64 inprops); + +uint64 InvertProperties(uint64 inprops); + +uint64 ProjectProperties(uint64 inprops, bool project_input); + +uint64 RandGenProperties(uint64 inprops, bool weighted); + +uint64 RelabelProperties(uint64 inprops); + +uint64 ReplaceProperties(const std::vector &inprops, size_t root, + bool epsilon_on_call, bool epsilon_on_return, + bool out_epsilon_on_call, bool out_epsilon_on_return, + bool replace_transducer, bool no_empty_fst, + bool all_ilabel_sorted, bool all_olabel_sorted, + bool all_negative_or_dense); + +uint64 ReverseProperties(uint64 inprops, bool has_superinitial); + +uint64 ReweightProperties(uint64 inprops); + +uint64 RmEpsilonProperties(uint64 inprops, bool delayed = false); + +uint64 ShortestPathProperties(uint64 props, bool tree = false); + +uint64 SynchronizeProperties(uint64 inprops); + +uint64 UnionProperties(uint64 inprops1, uint64 inprops2, bool delayed = false); + +// Definitions of inlined functions. + +uint64 SetStartProperties(uint64 inprops) { + auto outprops = inprops & kSetStartProperties; + if (inprops & kAcyclic) { + outprops |= kInitialAcyclic; + } + return outprops; +} + +uint64 AddStateProperties(uint64 inprops) { + return inprops & kAddStateProperties; +} + +uint64 DeleteStatesProperties(uint64 inprops) { + return inprops & kDeleteStatesProperties; +} + +uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticprops) { + const auto outprops = inprops & kError; + return outprops | kNullProperties | staticprops; +} + +uint64 DeleteArcsProperties(uint64 inprops) { + return inprops & kDeleteArcsProperties; +} + +// Definitions of template functions. + +template +uint64 SetFinalProperties(uint64 inprops, const Weight &old_weight, + const Weight &new_weight) { + auto outprops = inprops; + if (old_weight != Weight::Zero() && old_weight != Weight::One()) { + outprops &= ~kWeighted; + } + if (new_weight != Weight::Zero() && new_weight != Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + outprops &= kSetFinalProperties | kWeighted | kUnweighted; + return outprops; +} + +/// Gets the properties for the MutableFst::AddArc method. +/// +/// \param inprops the current properties of the FST +/// \param s the ID of the state to which an arc is being added. +/// \param arc the arc being added to the state with the specified ID +/// \param prev_arc the previously-added (or "last") arc of state s, or nullptr +// if s currently has no arcs. +template +uint64 AddArcProperties(uint64 inprops, typename Arc::StateId s, + const Arc &arc, const Arc *prev_arc) { + using Weight = typename Arc::Weight; + auto outprops = inprops; + if (arc.ilabel != arc.olabel) { + outprops |= kNotAcceptor; + outprops &= ~kAcceptor; + } + if (arc.ilabel == 0) { + outprops |= kIEpsilons; + outprops &= ~kNoIEpsilons; + if (arc.olabel == 0) { + outprops |= kEpsilons; + outprops &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + outprops |= kOEpsilons; + outprops &= ~kNoOEpsilons; + } + if (prev_arc) { + if (prev_arc->ilabel > arc.ilabel) { + outprops |= kNotILabelSorted; + outprops &= ~kILabelSorted; + } + if (prev_arc->olabel > arc.olabel) { + outprops |= kNotOLabelSorted; + outprops &= ~kOLabelSorted; + } + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + if (arc.nextstate <= s) { + outprops |= kNotTopSorted; + outprops &= ~kTopSorted; + } + outprops &= kAddArcProperties | kAcceptor | kNoEpsilons | kNoIEpsilons | + kNoOEpsilons | kILabelSorted | kOLabelSorted | kUnweighted | + kTopSorted; + if (outprops & kTopSorted) { + outprops |= kAcyclic | kInitialAcyclic; + } + return outprops; +} + +extern const char *PropertyNames[]; + +} // namespace fst + +#endif // FST_PROPERTIES_H_ diff --git a/projects/llm_framework/include/fst/prune.h b/projects/llm_framework/include/fst/prune.h new file mode 100644 index 00000000..9e7c04bd --- /dev/null +++ b/projects/llm_framework/include/fst/prune.h @@ -0,0 +1,341 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions implementing pruning. + +#ifndef FST_PRUNE_H_ +#define FST_PRUNE_H_ + +#include +#include +#include + +#include + +#include +#include +#include + + +namespace fst { +namespace internal { + +template +class PruneCompare { + public: + PruneCompare(const std::vector &idistance, + const std::vector &fdistance) + : idistance_(idistance), fdistance_(fdistance) {} + + bool operator()(const StateId x, const StateId y) const { + const auto wx = Times(IDistance(x), FDistance(x)); + const auto wy = Times(IDistance(y), FDistance(y)); + return less_(wx, wy); + } + + private: + Weight IDistance(const StateId s) const { + return s < idistance_.size() ? idistance_[s] : Weight::Zero(); + } + + Weight FDistance(const StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); + } + + const std::vector &idistance_; + const std::vector &fdistance_; + NaturalLess less_; +}; + +} // namespace internal + +template +struct PruneOptions { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId, + ArcFilter filter = ArcFilter(), + std::vector *distance = nullptr, + float delta = kDelta, bool threshold_initial = false) + : weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold), + filter(std::move(filter)), + distance(distance), + delta(delta), + threshold_initial(threshold_initial) {} + + // Pruning weight threshold. + Weight weight_threshold; + // Pruning state threshold. + StateId state_threshold; + // Arc filter. + ArcFilter filter; + // If non-zero, passes in pre-computed shortest distance to final states. + const std::vector *distance; + // Determines the degree of convergence required when computing shortest + // distances. + float delta; + // Determines if the shortest path weight is left (true) or right + // (false) multiplied by the threshold to get the limit for + // keeping a state or arc (matters if the semiring is not + // commutative). + bool threshold_initial; +}; + +// Pruning algorithm: this version modifies its input and it takes an options +// class as an argument. After pruning the FST contains states and arcs that +// belong to a successful path in the FST whose weight is no more than the +// weight of the shortest path Times() the provided weight threshold. When the +// state threshold is not kNoStateId, the output FST is further restricted to +// have no more than the number of states in opts.state_threshold. Weights must +// have the path property. The weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template ::value>::type * = + nullptr> +void Prune(MutableFst *fst, const PruneOptions &opts = + PruneOptions()) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using StateHeap = Heap>; + auto ns = fst->NumStates(); + if (ns < 1) return; + std::vector idistance(ns, Weight::Zero()); + std::vector tmp; + if (!opts.distance) { + tmp.reserve(ns); + ShortestDistance(*fst, &tmp, true, opts.delta); + } + const auto *fdistance = opts.distance ? opts.distance : &tmp; + if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) || + ((*fdistance)[fst->Start()] == Weight::Zero())) { + fst->DeleteStates(); + return; + } + internal::PruneCompare compare(idistance, *fdistance); + StateHeap heap(compare); + std::vector visited(ns, false); + std::vector enqueued(ns, StateHeap::kNoKey); + std::vector dead; + dead.push_back(fst->AddState()); + NaturalLess less; + auto s = fst->Start(); + const auto limit = opts.threshold_initial ? + Times(opts.weight_threshold, (*fdistance)[s]) : + Times((*fdistance)[s], opts.weight_threshold); + StateId num_visited = 0; + + if (!less(limit, (*fdistance)[s])) { + idistance[s] = Weight::One(); + enqueued[s] = heap.Insert(s); + ++num_visited; + } + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = StateHeap::kNoKey; + visited[s] = true; + if (less(limit, Times(idistance[s], fst->Final(s)))) { + fst->SetFinal(s, Weight::Zero()); + } + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); // Copy intended. + if (!opts.filter(arc)) continue; + const auto weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() ? + (*fdistance)[arc.nextstate] : Weight::Zero()); + if (less(limit, weight)) { + arc.nextstate = dead[0]; + aiter.SetValue(arc); + continue; + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + } + if (visited[arc.nextstate]) continue; + if ((opts.state_threshold != kNoStateId) && + (num_visited >= opts.state_threshold)) { + continue; + } + if (enqueued[arc.nextstate] == StateHeap::kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + ++num_visited; + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } + for (StateId i = 0; i < visited.size(); ++i) { + if (!visited[i]) dead.push_back(i); + } + fst->DeleteStates(dead); +} + +template ::value>::type + * = nullptr> +void Prune(MutableFst *fst, const PruneOptions &opts = + PruneOptions()) { + FSTERROR() << "Prune: Weight needs to have the path property: " + << Arc::Weight::Type(); + fst->SetProperties(kError, kError); +} + +// Pruning algorithm: this version modifies its input and takes the +// pruning threshold as an argument. It deletes states and arcs in the +// FST that do not belong to a successful path whose weight is more +// than the weight of the shortest path Times() the provided weight +// threshold. When the state threshold is not kNoStateId, the output +// FST is further restricted to have no more than the number of states +// in opts.state_threshold. Weights must have the path property. The +// weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template +void Prune(MutableFst *fst, typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + const PruneOptions> opts( + weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); + Prune(fst, opts); +} + +// Pruning algorithm: this version writes the pruned input FST to an +// output MutableFst and it takes an options class as an argument. The +// output FST contains states and arcs that belong to a successful +// path in the input FST whose weight is more than the weight of the +// shortest path Times() the provided weight threshold. When the state +// threshold is not kNoStateId, the output FST is further restricted +// to have no more than the number of states in +// opts.state_threshold. Weights have the path property. The weight +// of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template ::value>::type * = + nullptr> +void Prune( + const Fst &ifst, MutableFst *ofst, + const PruneOptions &opts = PruneOptions()) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using StateHeap = Heap>; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Start() == kNoStateId) return; + NaturalLess less; + if (less(opts.weight_threshold, Weight::One()) || + (opts.state_threshold == 0)) { + return; + } + std::vector idistance; + std::vector tmp; + if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta); + const auto *fdistance = opts.distance ? opts.distance : &tmp; + if ((fdistance->size() <= ifst.Start()) || + ((*fdistance)[ifst.Start()] == Weight::Zero())) { + return; + } + internal::PruneCompare compare(idistance, *fdistance); + StateHeap heap(compare); + std::vector copy; + std::vector enqueued; + std::vector visited; + auto s = ifst.Start(); + const auto limit = opts.threshold_initial ? + Times(opts.weight_threshold, (*fdistance)[s]) : + Times((*fdistance)[s], opts.weight_threshold); + while (copy.size() <= s) copy.push_back(kNoStateId); + copy[s] = ofst->AddState(); + ofst->SetStart(copy[s]); + while (idistance.size() <= s) idistance.push_back(Weight::Zero()); + idistance[s] = Weight::One(); + while (enqueued.size() <= s) { + enqueued.push_back(StateHeap::kNoKey); + visited.push_back(false); + } + enqueued[s] = heap.Insert(s); + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = StateHeap::kNoKey; + visited[s] = true; + if (!less(limit, Times(idistance[s], ifst.Final(s)))) { + ofst->SetFinal(copy[s], ifst.Final(s)); + } + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (!opts.filter(arc)) continue; + const auto weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() ? + (*fdistance)[arc.nextstate] : Weight::Zero()); + if (less(limit, weight)) continue; + if ((opts.state_threshold != kNoStateId) && + (ofst->NumStates() >= opts.state_threshold)) { + continue; + } + while (idistance.size() <= arc.nextstate) { + idistance.push_back(Weight::Zero()); + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + } + while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId); + if (copy[arc.nextstate] == kNoStateId) { + copy[arc.nextstate] = ofst->AddState(); + } + ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, + copy[arc.nextstate])); + while (enqueued.size() <= arc.nextstate) { + enqueued.push_back(StateHeap::kNoKey); + visited.push_back(false); + } + if (visited[arc.nextstate]) continue; + if (enqueued[arc.nextstate] == StateHeap::kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } +} + +template ::value>::type + * = nullptr> +void Prune(const Fst &, MutableFst *ofst, + const PruneOptions &) { + FSTERROR() << "Prune: Weight needs to have the path property: " + << Arc::Weight::Type(); + ofst->SetProperties(kError, kError); +} + +// Pruning algorithm: this version writes the pruned input FST to an +// output MutableFst and simply takes the pruning threshold as an +// argument. The output FST contains states and arcs that belong to a +// successful path in the input FST whose weight is no more than the +// weight of the shortest path Times() the provided weight +// threshold. When the state threshold is not kNoStateId, the output +// FST is further restricted to have no more than the number of states +// in opts.state_threshold. Weights must have the path property. The +// weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) = Weight::One(); +template +void Prune(const Fst &ifst, MutableFst *ofst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + const PruneOptions> opts( + weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); + Prune(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_PRUNE_H_ diff --git a/projects/llm_framework/include/fst/push.h b/projects/llm_framework/include/fst/push.h new file mode 100644 index 00000000..1f772739 --- /dev/null +++ b/projects/llm_framework/include/fst/push.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to reweight/push an FST, and utility functions to weigh and reweight +// an FST. + +#ifndef FST_PUSH_H_ +#define FST_PUSH_H_ + +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// Computes the total weight (sum of the weights of all accepting paths) from +// the output of ShortestDistance, using the shortest distance from the final +// state when reverse is true and from the initial state otherwise. +template +typename Arc::Weight ComputeTotalWeight( + const Fst &fst, const std::vector &distance, + bool reverse) { + if (reverse) { + return fst.Start() < distance.size() ? distance[fst.Start()] + : Arc::Weight::Zero(); + } + auto sum = Arc::Weight::Zero(); + for (typename Arc::StateId s = 0; s < distance.size(); ++s) { + sum = Plus(sum, Times(distance[s], fst.Final(s))); + } + return sum; +} + +// Divides the weight of every accepting path by a fixed weight. This weight +// is also divided at the final state if at_final is true and at the initial +// state otherwise. +template +void RemoveWeight(MutableFst *fst, const typename Arc::Weight &weight, + bool at_final) { + using Weight = typename Arc::Weight; + if ((weight == Weight::One()) || (weight == Weight::Zero())) return; + if (at_final) { + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + fst->SetFinal(siter.Value(), + Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT)); + } + } else { + const auto start = fst->Start(); + for (MutableArcIterator> aiter(fst, start); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT); + aiter.SetValue(arc); + } + fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT)); + } +} + +// Pushes the weights in FST in the direction defined by TYPE. If +// pushing towards the initial state, the sum of the weight of the +// outgoing transitions and final weight at a non-initial state is +// equal to One() in the resulting machine. If pushing towards the +// final state, the same property holds on the reverse machine. +// +// Weight needs to be left distributive when pushing towards the +// initial state and right distributive when pushing towards the final +// states. +template +void Push(MutableFst *fst, ReweightType type, float delta = kDelta, + bool remove_total_weight = false) { + using Weight = typename Arc::Weight; + std::vector distance; + ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); + auto total_weight = Weight::One(); + if (remove_total_weight) { + total_weight = + ComputeTotalWeight(*fst, distance, type == REWEIGHT_TO_INITIAL); + } + Reweight(fst, distance, type); + if (remove_total_weight) { + RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); + } +} + +constexpr uint32 kPushWeights = 0x0001; +constexpr uint32 kPushLabels = 0x0002; +constexpr uint32 kPushRemoveTotalWeight = 0x0004; +constexpr uint32 kPushRemoveCommonAffix = 0x0008; + +// Pushes the weights and/or labels of the input FST into the output +// mutable FST by pushing weights and/or labels (as determined by the +// ptype argument) towards the initial state or final states (as +// determined by the rtype template parameter). The weight type must +// be left distributive when pushing weights towards the initial state, and +// right distribution when pushing weights towards the final states. +template +void Push(const Fst &ifst, MutableFst *ofst, uint32 ptype, + float delta = kDelta) { + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { + *ofst = ifst; + Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); + } else if (ptype & kPushLabels) { + const auto gtype = + rtype == REWEIGHT_TO_INITIAL ? GALLIC_LEFT : GALLIC_RIGHT; + using GallicWeight = typename GallicArc::Weight; + std::vector gdistance; + VectorFst> gfst; + ArcMap(ifst, &gfst, ToGallicMapper()); + if (ptype & kPushWeights) { + ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } else { + ArcMapFst> uwfst(ifst, + RmWeightMapper()); + ArcMapFst, ToGallicMapper> guwfst( + uwfst, ToGallicMapper()); + ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } + auto total_weight = GallicWeight::One(); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + total_weight = + ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); + total_weight = GallicWeight( + ptype & kPushRemoveCommonAffix + ? total_weight.Value1() + : StringWeight::One(), + ptype & kPushRemoveTotalWeight ? total_weight.Value2() + : Weight::One()); + } + Reweight(&gfst, gdistance, rtype); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); + } + FactorWeightFst, GallicFactor> + fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else { + LOG(WARNING) << "Push: pushing type is set to 0, so not pushing"; + *ofst = ifst; + } +} + +} // namespace fst + +#endif // FST_PUSH_H_ diff --git a/projects/llm_framework/include/fst/queue.h b/projects/llm_framework/include/fst/queue.h new file mode 100644 index 00000000..f57d176e --- /dev/null +++ b/projects/llm_framework/include/fst/queue.h @@ -0,0 +1,948 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes for various FST state queues with a unified interface. + +#ifndef FST_QUEUE_H_ +#define FST_QUEUE_H_ + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// The Queue interface is: +// +// template +// class Queue { +// public: +// using StateId = S; +// +// // Constructor: may need args (e.g., FST, comparator) for some queues. +// Queue(...) override; +// +// // Returns the head of the queue. +// StateId Head() const override; +// +// // Inserts a state. +// void Enqueue(StateId s) override; +// +// // Removes the head of the queue. +// void Dequeue() override; +// +// // Updates ordering of state s when weight changes, if necessary. +// void Update(StateId s) override; +// +// // Is the queue empty? +// bool Empty() const override; +// +// // Removes all states from the queue. +// void Clear() override; +// }; + +// State queue types. +enum QueueType { + TRIVIAL_QUEUE = 0, // Single state queue. + FIFO_QUEUE = 1, // First-in, first-out queue. + LIFO_QUEUE = 2, // Last-in, first-out queue. + SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue. + TOP_ORDER_QUEUE = 4, // Topologically-ordered queue. + STATE_ORDER_QUEUE = 5, // State ID-ordered queue. + SCC_QUEUE = 6, // Component graph top-ordered meta-queue. + AUTO_QUEUE = 7, // Auto-selected queue. + OTHER_QUEUE = 8 +}; + +// QueueBase, templated on the StateId, is a virtual base class shared by all +// queues considered by AutoQueue. +template +class QueueBase { + public: + using StateId = S; + + virtual ~QueueBase() {} + + // Concrete implementation. + + explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {} + + void SetError(bool error) { error_ = error; } + + bool Error() const { return error_; } + + QueueType Type() const { return queue_type_; } + + // Virtual interface. + + virtual StateId Head() const = 0; + virtual void Enqueue(StateId) = 0; + virtual void Dequeue() = 0; + virtual void Update(StateId) = 0; + virtual bool Empty() const = 0; + virtual void Clear() = 0; + + private: + QueueType queue_type_; + bool error_; +}; + +// Trivial queue discipline; one may enqueue at most one state at a time. It +// can be used for strongly connected components with only one state and no +// self-loops. +template +class TrivialQueue : public QueueBase { + public: + using StateId = S; + + TrivialQueue() : QueueBase(TRIVIAL_QUEUE), front_(kNoStateId) {} + + virtual ~TrivialQueue() = default; + + StateId Head() const final { return front_; } + + void Enqueue(StateId s) final { front_ = s; } + + void Dequeue() final { front_ = kNoStateId; } + + void Update(StateId) final {} + + bool Empty() const final { return front_ == kNoStateId; } + + void Clear() final { front_ = kNoStateId; } + + private: + StateId front_; +}; + +// First-in, first-out queue discipline. +// +// This is not a final class. +template +class FifoQueue : public QueueBase { + public: + using StateId = S; + + FifoQueue() : QueueBase(FIFO_QUEUE) {} + + virtual ~FifoQueue() = default; + + StateId Head() const override { return queue_.back(); } + + void Enqueue(StateId s) override { queue_.push_front(s); } + + void Dequeue() override { queue_.pop_back(); } + + void Update(StateId) override {} + + bool Empty() const override { return queue_.empty(); } + + void Clear() override { queue_.clear(); } + + private: + std::deque queue_; +}; + +// Last-in, first-out queue discipline. +template +class LifoQueue : public QueueBase { + public: + using StateId = S; + + LifoQueue() : QueueBase(LIFO_QUEUE) {} + + virtual ~LifoQueue() = default; + + StateId Head() const final { return queue_.front(); } + + void Enqueue(StateId s) final { queue_.push_front(s); } + + void Dequeue() final { queue_.pop_front(); } + + void Update(StateId) final {} + + bool Empty() const final { return queue_.empty(); } + + void Clear() final { queue_.clear(); } + + private: + std::deque queue_; +}; + +// Shortest-first queue discipline, templated on the StateId and as well as a +// comparison functor used to compare two StateIds. If a (single) state's order +// changes, it can be reordered in the queue with a call to Update(). If update +// is false, call to Update() does not reorder the queue. +// +// This is not a final class. +template +class ShortestFirstQueue : public QueueBase { + public: + using StateId = S; + + explicit ShortestFirstQueue(Compare comp) + : QueueBase(SHORTEST_FIRST_QUEUE), heap_(comp) {} + + virtual ~ShortestFirstQueue() = default; + + StateId Head() const override { return heap_.Top(); } + + void Enqueue(StateId s) override { + if (update) { + for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId); + key_[s] = heap_.Insert(s); + } else { + heap_.Insert(s); + } + } + + void Dequeue() override { + if (update) { + key_[heap_.Pop()] = kNoStateId; + } else { + heap_.Pop(); + } + } + + void Update(StateId s) override { + if (!update) return; + if (s >= key_.size() || key_[s] == kNoStateId) { + Enqueue(s); + } else { + heap_.Update(key_[s], s); + } + } + + bool Empty() const override { return heap_.Empty(); } + + void Clear() override { + heap_.Clear(); + if (update) key_.clear(); + } + + const Compare &GetCompare() const { return heap_.GetCompare(); } + + private: + Heap heap_; + std::vector key_; +}; + +namespace internal { + +// Given a vector that maps from states to weights, and a comparison functor +// for weights, this class defines a comparison function object between states. +template +class StateWeightCompare { + public: + using Weight = typename Less::Weight; + + StateWeightCompare(const std::vector &weights, const Less &less) + : weights_(weights), less_(less) {} + + bool operator()(const StateId s1, const StateId s2) const { + return less_(weights_[s1], weights_[s2]); + } + + private: + // Borrowed references. + const std::vector &weights_; + const Less &less_; +}; + +} // namespace internal + +// Shortest-first queue discipline, templated on the StateId and Weight, is +// specialized to use the weight's natural order for the comparison function. +template +class NaturalShortestFirstQueue + : public ShortestFirstQueue< + S, internal::StateWeightCompare>> { + public: + using StateId = S; + using Compare = internal::StateWeightCompare>; + + explicit NaturalShortestFirstQueue(const std::vector &distance) + : ShortestFirstQueue(Compare(distance, less_)) {} + + virtual ~NaturalShortestFirstQueue() = default; + + private: + // This is non-static because the constructor for non-idempotent weights will + // result in an error. + const NaturalLess less_{}; +}; + +// In a shortest path computation on a lattice-like FST, we may keep many old +// nonviable paths as a part of the search. Since the search process always +// expands the lowest cost path next, that lowest cost path may be a very old +// nonviable path instead of one we expect to lead to a shortest path. +// +// For instance, suppose that the current best path in an alignment has +// traversed 500 arcs with a cost of 10. We may also have a bad path in +// the queue that has traversed only 40 arcs but also has a cost of 10. +// This path is very unlikely to lead to a reasonable alignment, so this queue +// can prune it from the search space. +// +// This queue relies on the caller using a shortest-first exploration order +// like this: +// while (true) { +// StateId head = queue.Head(); +// queue.Dequeue(); +// for (const auto& arc : GetArcs(fst, head)) { +// queue.Enqueue(arc.nextstate); +// } +// } +// We use this assumption to guess that there is an arc between Head and the +// Enqueued state; this is how the number of path steps is measured. +template +class PruneNaturalShortestFirstQueue + : public NaturalShortestFirstQueue { + public: + using StateId = S; + using Base = NaturalShortestFirstQueue; + + explicit PruneNaturalShortestFirstQueue(const std::vector &distance, + int threshold) + : Base(distance), + threshold_(threshold), + head_steps_(0), + max_head_steps_(0) {} + + ~PruneNaturalShortestFirstQueue() override = default; + + StateId Head() const override { + const auto head = Base::Head(); + // Stores the number of steps from the start of the graph to this state + // along the shortest-weight path. + if (head < steps_.size()) { + max_head_steps_ = std::max(steps_[head], max_head_steps_); + head_steps_ = steps_[head]; + } + return head; + } + + void Enqueue(StateId s) override { + // We assume that there is an arc between the Head() state and this + // Enqueued state. + const ssize_t state_steps = head_steps_ + 1; + if (s >= steps_.size()) { + steps_.resize(s + 1, state_steps); + } + // This is the number of arcs in the minimum cost path from Start to s. + steps_[s] = state_steps; + if (state_steps > (max_head_steps_ - threshold_) || threshold_ < 0) { + Base::Enqueue(s); + } + } + + private: + // A dense map from StateId to the number of arcs in the minimum weight + // path from Start to this state. + std::vector steps_; + // We only keep paths that are within this number of arcs (not weight!) + // of the longest path. + const ssize_t threshold_; + + // The following are mutable because Head() is const. + // The number of arcs traversed in the minimum cost path from the start + // state to the current Head() state. + mutable ssize_t head_steps_; + // The maximum number of arcs traversed by any low-cost path so far. + mutable ssize_t max_head_steps_; +}; + +// Topological-order queue discipline, templated on the StateId. States are +// ordered in the queue topologically. The FST must be acyclic. +template +class TopOrderQueue : public QueueBase { + public: + using StateId = S; + + // This constructor computes the topological order. It accepts an arc filter + // to limit the transitions considered in that computation (e.g., only the + // epsilon graph). + template + TopOrderQueue(const Fst &fst, ArcFilter filter) + : QueueBase(TOP_ORDER_QUEUE), + front_(0), + back_(kNoStateId), + order_(0), + state_(0) { + bool acyclic; + TopOrderVisitor top_order_visitor(&order_, &acyclic); + DfsVisit(fst, &top_order_visitor, filter); + if (!acyclic) { + FSTERROR() << "TopOrderQueue: FST is not acyclic"; + QueueBase::SetError(true); + } + state_.resize(order_.size(), kNoStateId); + } + + // This constructor is passed the pre-computed topological order. + explicit TopOrderQueue(const std::vector &order) + : QueueBase(TOP_ORDER_QUEUE), + front_(0), + back_(kNoStateId), + order_(order), + state_(order.size(), kNoStateId) {} + + virtual ~TopOrderQueue() = default; + + StateId Head() const final { return state_[front_]; } + + void Enqueue(StateId s) final { + if (front_ > back_) { + front_ = back_ = order_[s]; + } else if (order_[s] > back_) { + back_ = order_[s]; + } else if (order_[s] < front_) { + front_ = order_[s]; + } + state_[order_[s]] = s; + } + + void Dequeue() final { + state_[front_] = kNoStateId; + while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; + } + + void Update(StateId) final {} + + bool Empty() const final { return front_ > back_; } + + void Clear() final { + for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId; + back_ = kNoStateId; + front_ = 0; + } + + private: + StateId front_; + StateId back_; + std::vector order_; + std::vector state_; +}; + +// State order queue discipline, templated on the StateId. States are ordered in +// the queue by state ID. +template +class StateOrderQueue : public QueueBase { + public: + using StateId = S; + + StateOrderQueue() + : QueueBase(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} + + virtual ~StateOrderQueue() = default; + + StateId Head() const final { return front_; } + + void Enqueue(StateId s) final { + if (front_ > back_) { + front_ = back_ = s; + } else if (s > back_) { + back_ = s; + } else if (s < front_) { + front_ = s; + } + while (enqueued_.size() <= s) enqueued_.push_back(false); + enqueued_[s] = true; + } + + void Dequeue() final { + enqueued_[front_] = false; + while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; + } + + void Update(StateId) final {} + + bool Empty() const final { return front_ > back_; } + + void Clear() final { + for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; + front_ = 0; + back_ = kNoStateId; + } + + private: + StateId front_; + StateId back_; + std::vector enqueued_; +}; + +// SCC topological-order meta-queue discipline, templated on the StateId and a +// queue used inside each SCC. It visits the SCCs of an FST in topological +// order. Its constructor is passed the queues to to use within an SCC. +template +class SccQueue : public QueueBase { + public: + using StateId = S; + + // Constructor takes a vector specifying the SCC number per state and a + // vector giving the queue to use per SCC number. + SccQueue(const std::vector &scc, + std::vector> *queue) + : QueueBase(SCC_QUEUE), + queue_(queue), + scc_(scc), + front_(0), + back_(kNoStateId) {} + + virtual ~SccQueue() = default; + + StateId Head() const final { + while ((front_ <= back_) && + (((*queue_)[front_] && (*queue_)[front_]->Empty()) || + (((*queue_)[front_] == nullptr) && + ((front_ >= trivial_queue_.size()) || + (trivial_queue_[front_] == kNoStateId))))) { + ++front_; + } + if ((*queue_)[front_]) { + return (*queue_)[front_]->Head(); + } else { + return trivial_queue_[front_]; + } + } + + void Enqueue(StateId s) final { + if (front_ > back_) { + front_ = back_ = scc_[s]; + } else if (scc_[s] > back_) { + back_ = scc_[s]; + } else if (scc_[s] < front_) { + front_ = scc_[s]; + } + if ((*queue_)[scc_[s]]) { + (*queue_)[scc_[s]]->Enqueue(s); + } else { + while (trivial_queue_.size() <= scc_[s]) { + trivial_queue_.push_back(kNoStateId); + } + trivial_queue_[scc_[s]] = s; + } + } + + void Dequeue() final { + if ((*queue_)[front_]) { + (*queue_)[front_]->Dequeue(); + } else if (front_ < trivial_queue_.size()) { + trivial_queue_[front_] = kNoStateId; + } + } + + void Update(StateId s) final { + if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s); + } + + bool Empty() const final { + // Queues SCC number back_ is not empty unless back_ == front_. + if (front_ < back_) { + return false; + } else if (front_ > back_) { + return true; + } else if ((*queue_)[front_]) { + return (*queue_)[front_]->Empty(); + } else { + return (front_ >= trivial_queue_.size()) || + (trivial_queue_[front_] == kNoStateId); + } + } + + void Clear() final { + for (StateId i = front_; i <= back_; ++i) { + if ((*queue_)[i]) { + (*queue_)[i]->Clear(); + } else if (i < trivial_queue_.size()) { + trivial_queue_[i] = kNoStateId; + } + } + front_ = 0; + back_ = kNoStateId; + } + + private: + std::vector> *queue_; + const std::vector &scc_; + mutable StateId front_; + StateId back_; + std::vector trivial_queue_; +}; + +// Automatic queue discipline. It selects a queue discipline for a given FST +// based on its properties. +template +class AutoQueue : public QueueBase { + public: + using StateId = S; + + // This constructor takes a state distance vector that, if non-null and if + // the Weight type has the path property, will entertain the shortest-first + // queue using the natural order w.r.t to the distance. + template + AutoQueue(const Fst &fst, + const std::vector *distance, ArcFilter filter) + : QueueBase(AUTO_QUEUE) { + using Weight = typename Arc::Weight; + using Less = NaturalLess; + using Compare = internal::StateWeightCompare; + // First checks if the FST is known to have these properties. + const auto props = + fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false); + if ((props & kTopSorted) || fst.Start() == kNoStateId) { + queue_.reset(new StateOrderQueue()); + VLOG(2) << "AutoQueue: using state-order discipline"; + } else if (props & kAcyclic) { + queue_.reset(new TopOrderQueue(fst, filter)); + VLOG(2) << "AutoQueue: using top-order discipline"; + } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { + queue_.reset(new LifoQueue()); + VLOG(2) << "AutoQueue: using LIFO discipline"; + } else { + uint64 properties; + // Decomposes into strongly-connected components. + SccVisitor scc_visitor(&scc_, nullptr, nullptr, &properties); + DfsVisit(fst, &scc_visitor, filter); + auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1; + std::vector queue_types(nscc); + std::unique_ptr less; + std::unique_ptr comp; + if (distance && (Weight::Properties() & kPath) == kPath) { + less.reset(new Less); + comp.reset(new Compare(*distance, *less)); + } + // Finds the queue type to use per SCC. + bool unweighted; + bool all_trivial; + SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial, + &unweighted); + // If unweighted and semiring is idempotent, uses LIFO queue. + if (unweighted) { + queue_.reset(new LifoQueue()); + VLOG(2) << "AutoQueue: using LIFO discipline"; + return; + } + // If all the SCC are trivial, the FST is acyclic and the scc number gives + // the topological order. + if (all_trivial) { + queue_.reset(new TopOrderQueue(scc_)); + VLOG(2) << "AutoQueue: using top-order discipline"; + return; + } + VLOG(2) << "AutoQueue: using SCC meta-discipline"; + queues_.resize(nscc); + for (StateId i = 0; i < nscc; ++i) { + switch (queue_types[i]) { + case TRIVIAL_QUEUE: + queues_[i].reset(); + VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline"; + break; + case SHORTEST_FIRST_QUEUE: + queues_[i].reset( + new ShortestFirstQueue(*comp)); + VLOG(3) << "AutoQueue: SCC #" << i + << ": using shortest-first discipline"; + break; + case LIFO_QUEUE: + queues_[i].reset(new LifoQueue()); + VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline"; + break; + case FIFO_QUEUE: + default: + queues_[i].reset(new FifoQueue()); + VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine"; + break; + } + } + queue_.reset(new SccQueue>(scc_, &queues_)); + } + } + + virtual ~AutoQueue() = default; + + StateId Head() const final { return queue_->Head(); } + + void Enqueue(StateId s) final { queue_->Enqueue(s); } + + void Dequeue() final { queue_->Dequeue(); } + + void Update(StateId s) final { queue_->Update(s); } + + bool Empty() const final { return queue_->Empty(); } + + void Clear() final { queue_->Clear(); } + + private: + template + static void SccQueueType(const Fst &fst, const std::vector &scc, + std::vector *queue_types, + ArcFilter filter, Less *less, bool *all_trivial, + bool *unweighted); + + std::unique_ptr> queue_; + std::vector>> queues_; + std::vector scc_; +}; + +// Examines the states in an FST's strongly connected components and determines +// which type of queue to use per SCC. Stores result as a vector of QueueTypes +// which is assumed to have length equal to the number of SCCs. An arc filter +// is used to limit the transitions considered (e.g., only the epsilon graph). +// The argument all_trivial is set to true if every queue is the trivial queue. +// The argument unweighted is set to true if the semiring is idempotent and all +// the arc weights are equal to Zero() or One(). +template +template +void AutoQueue::SccQueueType(const Fst &fst, + const std::vector &scc, + std::vector *queue_type, + ArcFilter filter, Less *less, + bool *all_trivial, bool *unweighted) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + *all_trivial = true; + *unweighted = true; + for (StateId i = 0; i < queue_type->size(); ++i) { + (*queue_type)[i] = TRIVIAL_QUEUE; + } + for (StateIterator> sit(fst); !sit.Done(); sit.Next()) { + const auto state = sit.Value(); + for (ArcIterator> ait(fst, state); !ait.Done(); ait.Next()) { + const auto &arc = ait.Value(); + if (!filter(arc)) continue; + if (scc[state] == scc[arc.nextstate]) { + auto &type = (*queue_type)[scc[state]]; + if (!less || ((*less)(arc.weight, Weight::One()))) { + type = FIFO_QUEUE; + } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { + type = SHORTEST_FIRST_QUEUE; + } else { + type = LIFO_QUEUE; + } + } + if (type != TRIVIAL_QUEUE) *all_trivial = false; + } + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { + *unweighted = false; + } + } + } +} + +// An A* estimate is a function object that maps from a state ID to an +// estimate of the shortest distance to the final states. + +// A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's +// algorithm. +template +struct TrivialAStarEstimate { + constexpr Weight operator()(StateId) const { return Weight::One(); } +}; + +// A non-trivial A* estimate using a vector of the estimated future costs. +template +class NaturalAStarEstimate { + public: + NaturalAStarEstimate(const std::vector &beta) : beta_(beta) {} + + const Weight &operator()(StateId s) const { + return (s < beta_.size()) ? beta_[s] : kZero; + } + + private: + static constexpr Weight kZero = Weight::Zero(); + + const std::vector &beta_; +}; + +template +constexpr Weight NaturalAStarEstimate::kZero; + +// Given a vector that maps from states to weights representing the shortest +// distance from the initial state, a comparison function object between +// weights, and an estimate of the shortest distance to the final states, this +// class defines a comparison function object between states. +template +class AStarWeightCompare { + public: + using StateId = S; + using Weight = typename Less::Weight; + + AStarWeightCompare(const std::vector &weights, const Less &less, + const Estimate &estimate) + : weights_(weights), less_(less), estimate_(estimate) {} + + bool operator()(StateId s1, StateId s2) const { + const auto w1 = Times(weights_[s1], estimate_(s1)); + const auto w2 = Times(weights_[s2], estimate_(s2)); + return less_(w1, w2); + } + + const Estimate &GetEstimate() const { return estimate_; } + + private: + const std::vector &weights_; + const Less &less_; + const Estimate &estimate_; +}; + +// A* queue discipline templated on StateId, Weight, and Estimate. +template +class NaturalAStarQueue : public ShortestFirstQueue< + S, AStarWeightCompare, Estimate>> { + public: + using StateId = S; + using Compare = AStarWeightCompare, Estimate>; + + NaturalAStarQueue(const std::vector &distance, + const Estimate &estimate) + : ShortestFirstQueue( + Compare(distance, less_, estimate)) {} + + ~NaturalAStarQueue() = default; + + private: + // This is non-static because the constructor for non-idempotent weights will + // result in an error. + const NaturalLess less_{}; +}; + +// A state equivalence class is a function object that maps from a state ID to +// an equivalence class (state) ID. The trivial equivalence class maps a state +// ID to itself. +template +struct TrivialStateEquivClass { + StateId operator()(StateId s) const { return s; } +}; + +// Distance-based pruning queue discipline: Enqueues a state only when its +// shortest distance (so far), as specified by distance, is less than (as +// specified by comp) the shortest distance Times() the threshold to any state +// in the same equivalence class, as specified by the functor class_func. The +// underlying queue discipline is specified by queue. The ownership of queue is +// given to this class. +// +// This is not a final class. +template +class PruneQueue : public QueueBase { + public: + using StateId = typename Queue::StateId; + using Weight = typename Less::Weight; + + PruneQueue(const std::vector &distance, Queue *queue, + const Less &less, const ClassFnc &class_fnc, Weight threshold) + : QueueBase(OTHER_QUEUE), + distance_(distance), + queue_(queue), + less_(less), + class_fnc_(class_fnc), + threshold_(std::move(threshold)) {} + + virtual ~PruneQueue() = default; + + StateId Head() const override { return queue_->Head(); } + + void Enqueue(StateId s) override { + const auto c = class_fnc_(s); + if (c >= class_distance_.size()) { + class_distance_.resize(c + 1, Weight::Zero()); + } + if (less_(distance_[s], class_distance_[c])) { + class_distance_[c] = distance_[s]; + } + // Enqueues only if below threshold limit. + const auto limit = Times(class_distance_[c], threshold_); + if (less_(distance_[s], limit)) queue_->Enqueue(s); + } + + void Dequeue() override { queue_->Dequeue(); } + + void Update(StateId s) override { + const auto c = class_fnc_(s); + if (less_(distance_[s], class_distance_[c])) { + class_distance_[c] = distance_[s]; + } + queue_->Update(s); + } + + bool Empty() const override { return queue_->Empty(); } + + void Clear() override { queue_->Clear(); } + + private: + const std::vector &distance_; // Shortest distance to state. + std::unique_ptr queue_; + const Less &less_; // Borrowed reference. + const ClassFnc &class_fnc_; // Equivalence class functor. + Weight threshold_; // Pruning weight threshold. + std::vector class_distance_; // Shortest distance to class. +}; + +// Pruning queue discipline (see above) using the weight's natural order for the +// comparison function. The ownership of the queue argument is given to this +// class. +template +class NaturalPruneQueue final + : public PruneQueue, ClassFnc> { + public: + using StateId = typename Queue::StateId; + + NaturalPruneQueue(const std::vector &distance, Queue *queue, + const ClassFnc &class_fnc, Weight threshold) + : PruneQueue, ClassFnc>( + distance, queue, NaturalLess(), class_fnc, threshold) {} + + virtual ~NaturalPruneQueue() = default; +}; + +// Filter-based pruning queue discipline: enqueues a state only if allowed by +// the filter, specified by the state filter functor argument. The underlying +// queue discipline is specified by the queue argument. The ownership of the +// queue is given to this class. +template +class FilterQueue : public QueueBase { + public: + using StateId = typename Queue::StateId; + + FilterQueue(Queue *queue, const Filter &filter) + : QueueBase(OTHER_QUEUE), queue_(queue), filter_(filter) {} + + virtual ~FilterQueue() = default; + + StateId Head() const final { return queue_->Head(); } + + // Enqueues only if allowed by state filter. + void Enqueue(StateId s) final { + if (filter_(s)) queue_->Enqueue(s); + } + + void Dequeue() final { queue_->Dequeue(); } + + void Update(StateId s) final {} + + bool Empty() const final { return queue_->Empty(); } + + void Clear() final { queue_->Clear(); } + + private: + std::unique_ptr queue_; + const Filter &filter_; +}; + +} // namespace fst + +#endif // FST_QUEUE_H_ diff --git a/projects/llm_framework/include/fst/randequivalent.h b/projects/llm_framework/include/fst/randequivalent.h new file mode 100644 index 00000000..73108b46 --- /dev/null +++ b/projects/llm_framework/include/fst/randequivalent.h @@ -0,0 +1,114 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Tests if two FSTS are equivalent by checking if random strings from one FST +// are transduced the same by both FSTs. + +#ifndef FST_RANDEQUIVALENT_H_ +#define FST_RANDEQUIVALENT_H_ + +#include + +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Test if two FSTs are stochastically equivalent by randomly generating +// random paths through the FSTs. +// +// For each randomly generated path, the algorithm computes for each +// of the two FSTs the sum of the weights of all the successful paths +// sharing the same input and output labels as the considered randomly +// generated path and checks that these two values are within a user-specified +// delta. Returns optional error value (when FLAGS_error_fatal = false). +template +bool RandEquivalent(const Fst &fst1, const Fst &fst2, + int32 num_paths, float delta, + const RandGenOptions &opts, + bool *error = nullptr) { + using Weight = typename Arc::Weight; + if (error) *error = false; + // Checks that the symbol table are compatible. + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st " + << "argument do not match input/output symbol tables of 2nd " + << "argument"; + if (error) *error = true; + return false; + } + static const ILabelCompare icomp; + static const OLabelCompare ocomp; + VectorFst sfst1(fst1); + VectorFst sfst2(fst2); + Connect(&sfst1); + Connect(&sfst2); + ArcSort(&sfst1, icomp); + ArcSort(&sfst2, icomp); + bool result = true; + for (int32 n = 0; n < num_paths; ++n) { + VectorFst path; + const auto &fst = rand() % 2 ? sfst1 : sfst2; // NOLINT + RandGen(fst, &path, opts); + VectorFst ipath(path); + VectorFst opath(path); + Project(&ipath, PROJECT_INPUT); + Project(&opath, PROJECT_OUTPUT); + VectorFst cfst1, pfst1; + Compose(ipath, sfst1, &cfst1); + ArcSort(&cfst1, ocomp); + Compose(cfst1, opath, &pfst1); + // Gives up if there are epsilon cycles in a non-idempotent semiring. + if (!(Weight::Properties() & kIdempotent) && + pfst1.Properties(kCyclic, true)) { + continue; + } + const auto sum1 = ShortestDistance(pfst1); + VectorFst cfst2; + Compose(ipath, sfst2, &cfst2); + ArcSort(&cfst2, ocomp); + VectorFst pfst2; + Compose(cfst2, opath, &pfst2); + // Gives up if there are epsilon cycles in a non-idempotent semiring. + if (!(Weight::Properties() & kIdempotent) && + pfst2.Properties(kCyclic, true)) { + continue; + } + const auto sum2 = ShortestDistance(pfst2); + if (!ApproxEqual(sum1, sum2, delta)) { + VLOG(1) << "Sum1 = " << sum1; + VLOG(1) << "Sum2 = " << sum2; + result = false; + break; + } + } + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + return result; +} + +// Tests if two FSTs are equivalent by randomly generating a nnum_paths paths +// (no longer than the path_length) using a user-specified seed, optionally +// indicating an error setting an optional error argument to true. +template +bool RandEquivalent(const Fst &fst1, const Fst &fst2, int32 num_paths, + float delta = kDelta, time_t seed = time(nullptr), + int32 max_length = std::numeric_limits::max(), + bool *error = nullptr) { + const UniformArcSelector uniform_selector(seed); + const RandGenOptions> opts(uniform_selector, + max_length); + return RandEquivalent(fst1, fst2, num_paths, delta, opts, error); +} + +} // namespace fst + +#endif // FST_RANDEQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/randgen.h b/projects/llm_framework/include/fst/randgen.h new file mode 100644 index 00000000..5bcd9fd0 --- /dev/null +++ b/projects/llm_framework/include/fst/randgen.h @@ -0,0 +1,756 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions to generate random paths through an FST. + +#ifndef FST_RANDGEN_H_ +#define FST_RANDGEN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// The RandGenFst class is roughly similar to ArcMapFst in that it takes two +// template parameters denoting the input and output arc types. However, it also +// takes an additional template parameter which specifies a sampler object which +// samples (with replacement) arcs from an FST state. The sampler in turn takes +// a template parameter for a selector object which actually chooses the arc. +// +// Arc selector functors are used to select a random transition given an FST +// state s, returning a number N such that 0 <= N <= NumArcs(s). If N is +// NumArcs(s), then the final weight is selected; otherwise the N-th arc is +// selected. It is assumed these are not applied to any state which is neither +// final nor has any arcs leaving it. + +// Randomly selects a transition using the uniform distribution. This class is +// not thread-safe. +template +class UniformArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // Constructs a selector with a non-deterministic seed. + UniformArcSelector() : rand_(std::random_device()()) {} + // Constructs a selector with a given seed. + explicit UniformArcSelector(uint64 seed) : rand_(seed) {} + + size_t operator()(const Fst &fst, StateId s) const { + const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero()); + return static_cast( + std::uniform_int_distribution<>(0, n - 1)(rand_)); + } + + private: + mutable std::mt19937_64 rand_; +}; + +// Randomly selects a transition w.r.t. the weights treated as negative log +// probabilities after normalizing for the total weight leaving the state. Zero +// transitions are disregarded. It assumed that Arc::Weight::Value() accesses +// the floating point representation of the weight. This class is not +// thread-safe. +template +class LogProbArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // Constructs a selector with a non-deterministic seed. + LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {} + // Constructs a selector with a given seed. + explicit LogProbArcSelector(uint64 seed) : seed_(seed), rand_(seed) {} + + size_t operator()(const Fst &fst, StateId s) const { + // Finds total weight leaving state. + auto sum = Log64Weight::Zero(); + ArcIterator> aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + sum = Plus(sum, to_log_weight_(arc.weight)); + } + sum = Plus(sum, to_log_weight_(fst.Final(s))); + const double threshold = + std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_); + auto p = Log64Weight::Zero(); + size_t n = 0; + for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) { + p = Plus(p, to_log_weight_(aiter.Value().weight)); + if (exp(-p.Value()) > threshold) return n; + } + return n; + } + + uint64 Seed() const { return seed_; } + + protected: + Log64Weight ToLogWeight(const Weight &weight) const { + return to_log_weight_(weight); + } + + std::mt19937_64 &MutableRand() const { return rand_; } + + private: + const uint64 seed_; + mutable std::mt19937_64 rand_; + const WeightConvert to_log_weight_{}; +}; + +// Useful alias when using StdArc. +using StdArcSelector = LogProbArcSelector; + +// Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight +// accumulation computations. This class is not thread-safe. +template +class FastLogProbArcSelector : public LogProbArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using LogProbArcSelector::MutableRand; + using LogProbArcSelector::ToLogWeight; + using LogProbArcSelector::operator(); + + // Constructs a selector with a non-deterministic seed. + FastLogProbArcSelector() : LogProbArcSelector() {} + // Constructs a selector with a given seed. + explicit FastLogProbArcSelector(uint64 seed) : LogProbArcSelector( + seed) {} + + size_t operator()(const Fst &fst, StateId s, + CacheLogAccumulator *accumulator) const { + accumulator->SetState(s); + ArcIterator> aiter(fst, s); + // Finds total weight leaving state. + const double sum = + ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s))) + .Value(); + const double r = -log(std::uniform_real_distribution<>(0, 1)( + MutableRand())); + Weight w = from_log_weight_(r + sum); + aiter.Reset(); + return accumulator->LowerBound(w, &aiter); + } + + private: + const WeightConvert from_log_weight_{}; +}; + +// Random path state info maintained by RandGenFst and passed to samplers. +template +struct RandState { + using StateId = typename Arc::StateId; + + StateId state_id; // Current input FST state. + size_t nsamples; // Number of samples to be sampled at this state. + size_t length; // Length of path to this random state. + size_t select; // Previous sample arc selection. + const RandState *parent; // Previous random state on this path. + + explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0, + size_t select = 0, const RandState *parent = nullptr) + : state_id(state_id), + nsamples(nsamples), + length(length), + select(select), + parent(parent) {} + + RandState() : RandState(kNoStateId) {} +}; + +// This class, given an arc selector, samples, with replacement, multiple random +// transitions from an FST's state. This is a generic version with a +// straightforward use of the arc selector. Specializations may be defined for +// arc selectors for greater efficiency or special behavior. +template +class ArcSampler { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // The max_length argument may be interpreted (or ignored) by a selector as + // it chooses. This generic version interprets this literally. + ArcSampler(const Fst &fst, const Selector &selector, + int32 max_length = std::numeric_limits::max()) + : fst_(fst), selector_(selector), max_length_(max_length) {} + + // Allow updating FST argument; pass only if changed. + ArcSampler(const ArcSampler &sampler, + const Fst *fst = nullptr) + : fst_(fst ? *fst : sampler.fst_), + selector_(sampler.selector_), + max_length_(sampler.max_length_) { + Reset(); + } + + // Samples a fixed number of samples from the given state. The length argument + // specifies the length of the path to the state. Returns true if the samples + // were collected. No samples may be collected if either there are no + // transitions leaving the state and the state is non-final, or if the path + // length has been exceeded. Iterator members are provided to read the samples + // in the order in which they were collected. + bool Sample(const RandState &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + for (size_t i = 0; i < rstate.nsamples; ++i) { + ++sample_map_[selector_(fst_, rstate.state_id)]; + } + Reset(); + return true; + } + + // More samples? + bool Done() const { return sample_iter_ == sample_map_.end(); } + + // Gets the next sample. + void Next() { ++sample_iter_; } + + std::pair Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return false; } + + private: + const Fst &fst_; + const Selector &selector_; + const int32 max_length_; + + // Stores (N, K) as described for Value(). + std::map sample_map_; + std::map::const_iterator sample_iter_; + + ArcSampler &operator=(const ArcSampler &) = delete; +}; + +// Samples one sample of num_to_sample dimensions from a multinomial +// distribution parameterized by a vector of probabilities. The result +// container should be pre-initialized (e.g., an empty map or a zeroed vector +// sized the same as the vector of probabilities. +// probs.size()). +template +void OneMultinomialSample(const std::vector &probs, + size_t num_to_sample, Result *result, RNG *rng) { + // Left-over probability mass. + double norm = 0; + for (double p : probs) norm += p; + // Left-over number of samples needed. + for (size_t i = 0; i < probs.size(); ++i) { + size_t num_sampled = 0; + if (probs[i] > 0) { + std::binomial_distribution<> d(num_to_sample, probs[i] / norm); + num_sampled = d(*rng); + } + if (num_sampled != 0) (*result)[i] = num_sampled; + norm -= probs[i]; + num_to_sample -= num_sampled; + } +} + +// Specialization for FastLogProbArcSelector. +template +class ArcSampler> { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Accumulator = CacheLogAccumulator; + using Selector = FastLogProbArcSelector; + + ArcSampler(const Fst &fst, const Selector &selector, + int32 max_length = std::numeric_limits::max()) + : fst_(fst), + selector_(selector), + max_length_(max_length), + accumulator_(new Accumulator()) { + accumulator_->Init(fst); + rng_.seed(selector_.Seed()); + } + + ArcSampler(const ArcSampler &sampler, + const Fst *fst = nullptr) + : fst_(fst ? *fst : sampler.fst_), + selector_(sampler.selector_), + max_length_(sampler.max_length_) { + if (fst) { + accumulator_.reset(new Accumulator()); + accumulator_->Init(*fst); + } else { // Shallow copy. + accumulator_.reset(new Accumulator(*sampler.accumulator_)); + } + } + + bool Sample(const RandState &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) { + MultinomialSample(rstate); + Reset(); + return true; + } + for (size_t i = 0; i < rstate.nsamples; ++i) { + ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())]; + } + Reset(); + return true; + } + + bool Done() const { return sample_iter_ == sample_map_.end(); } + + void Next() { ++sample_iter_; } + + std::pair Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return accumulator_->Error(); } + + private: + using RNG = std::mt19937; + + // Sample according to the multinomial distribution of rstate.nsamples draws + // from p_. + void MultinomialSample(const RandState &rstate) { + p_.clear(); + for (ArcIterator> aiter(fst_, rstate.state_id); !aiter.Done(); + aiter.Next()) { + p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value())); + } + if (fst_.Final(rstate.state_id) != Weight::Zero()) { + p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value())); + } + if (rstate.nsamples < std::numeric_limits::max()) { + OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_); + } else { + for (size_t i = 0; i < p_.size(); ++i) { + sample_map_[i] = ceil(p_[i] * rstate.nsamples); + } + } + } + + const Fst &fst_; + const Selector &selector_; + const int32 max_length_; + + // Stores (N, K) for Value(). + std::map sample_map_; + std::map::const_iterator sample_iter_; + + std::unique_ptr accumulator_; + RNG rng_; // Random number generator. + std::vector p_; // Multinomial parameters. + const WeightConvert to_log_weight_{}; +}; + +// Options for random path generation with RandGenFst. The template argument is +// a sampler, typically the class ArcSampler. Ownership of the sampler is taken +// by RandGenFst. +template +struct RandGenFstOptions : public CacheOptions { + Sampler *sampler; // How to sample transitions at a state. + int32 npath; // Number of paths to generate. + bool weighted; // Is the output tree weighted by path count, or + // is it just an unweighted DAG? + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath = 1, + bool weighted = true, bool remove_total_weight = false) + : CacheOptions(opts), + sampler(sampler), + npath(npath), + weighted(weighted), + remove_total_weight(remove_total_weight) {} +}; + +namespace internal { + +// Implementation of RandGenFst. +template +class RandGenFstImpl : public CacheImpl { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::EmplaceArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + using Label = typename FromArc::Label; + using StateId = typename FromArc::StateId; + using FromWeight = typename FromArc::Weight; + + using ToWeight = typename ToArc::Weight; + + RandGenFstImpl(const Fst &fst, + const RandGenFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + sampler_(opts.sampler), + npath_(opts.npath), + weighted_(opts.weighted), + remove_total_weight_(opts.remove_total_weight), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties( + RandGenProperties(fst.Properties(kFstProperties, false), weighted_), + kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RandGenFstImpl(const RandGenFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + sampler_(new Sampler(*impl.sampler_, fst_.get())), + npath_(impl.npath_), + weighted_(impl.weighted_), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) { + const auto s = fst_->Start(); + if (s == kNoStateId) return kNoStateId; + SetStart(state_table_.size()); + state_table_.emplace_back( + new RandState(s, npath_, 0, 0, nullptr)); + } + return CacheImpl::Start(); + } + + ToWeight Final(StateId s) { + if (!HasFinal(s)) Expand(s); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && + (fst_->Properties(kError, false) || sampler_->Error())) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + if (s == superfinal_) { + SetFinal(s, ToWeight::One()); + SetArcs(s); + return; + } + SetFinal(s, ToWeight::Zero()); + const auto &rstate = *state_table_[s]; + sampler_->Sample(rstate); + ArcIterator> aiter(*fst_, rstate.state_id); + const auto narcs = fst_->NumArcs(rstate.state_id); + for (; !sampler_->Done(); sampler_->Next()) { + const auto &sample_pair = sampler_->Value(); + const auto pos = sample_pair.first; + const auto count = sample_pair.second; + double prob = static_cast(count) / rstate.nsamples; + if (pos < narcs) { // Regular transition. + aiter.Seek(sample_pair.first); + const auto &aarc = aiter.Value(); + auto weight = + weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One(); + EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight), + state_table_.size()); + auto *nrstate = new RandState(aarc.nextstate, count, + rstate.length + 1, pos, &rstate); + state_table_.emplace_back(nrstate); + } else { // Super-final transition. + if (weighted_) { + const auto weight = + remove_total_weight_ + ? to_weight_(Log64Weight(-log(prob))) + : to_weight_(Log64Weight(-log(prob * npath_))); + SetFinal(s, weight); + } else { + if (superfinal_ == kNoLabel) { + superfinal_ = state_table_.size(); + state_table_.emplace_back( + new RandState(kNoStateId, 0, 0, 0, nullptr)); + } + for (size_t n = 0; n < count; ++n) { + EmplaceArc(s, 0, 0, ToWeight::One(), superfinal_); + } + } + } + } + SetArcs(s); + } + + private: + const std::unique_ptr> fst_; + std::unique_ptr sampler_; + const int32 npath_; + std::vector>> state_table_; + const bool weighted_; + bool remove_total_weight_; + StateId superfinal_; + const WeightConvert to_weight_{}; +}; + +} // namespace internal + +// FST class to randomly generate paths through an FST, with details controlled +// by RandGenOptionsFst. Output format is a tree weighted by the path count. +template +class RandGenFst + : public ImplToFst> { + public: + using Label = typename FromArc::Label; + using StateId = typename FromArc::StateId; + using Weight = typename FromArc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using Impl = internal::RandGenFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + RandGenFst(const Fst &fst, const RandGenFstOptions &opts) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RandGenFst(const RandGenFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. + RandGenFst *Copy(bool safe = false) const override { + return new RandGenFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RandGenFst &operator=(const RandGenFst &) = delete; +}; + +// Specialization for RandGenFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const RandGenFst &fst) + : CacheStateIterator>( + fst, fst.GetMutableImpl()) {} +}; + +// Specialization for RandGenFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename FromArc::StateId; + + ArcIterator(const RandGenFst &fst, StateId s) + : CacheArcIterator>( + fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RandGenFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Options for random path generation. +template +struct RandGenOptions { + const Selector &selector; // How an arc is selected at a state. + int32 max_length; // Maximum path length. + int32 npath; // Number of paths to generate. + bool weighted; // Is the output tree weighted by path count, or + // is it just an unweighted DAG? + bool remove_total_weight; // Remove total weight when output is weighted? + + explicit RandGenOptions(const Selector &selector, + int32 max_length = std::numeric_limits::max(), + int32 npath = 1, bool weighted = false, + bool remove_total_weight = false) + : selector(selector), + max_length(max_length), + npath(npath), + weighted(weighted), + remove_total_weight(remove_total_weight) {} +}; + +namespace internal { + +template +class RandGenVisitor { + public: + using StateId = typename FromArc::StateId; + using Weight = typename FromArc::Weight; + + explicit RandGenVisitor(MutableFst *ofst) : ofst_(ofst) {} + + void InitVisit(const Fst &ifst) { + ifst_ = &ifst; + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst.InputSymbols()); + ofst_->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError); + path_.clear(); + } + + constexpr bool InitState(StateId, StateId) const { return true; } + + bool TreeArc(StateId, const ToArc &arc) { + if (ifst_->Final(arc.nextstate) == Weight::Zero()) { + path_.push_back(arc); + } else { + OutputPath(); + } + return true; + } + + bool BackArc(StateId, const FromArc &) { + FSTERROR() << "RandGenVisitor: cyclic input"; + ofst_->SetProperties(kError, kError); + return false; + } + + bool ForwardOrCrossArc(StateId, const FromArc &) { + OutputPath(); + return true; + } + + void FinishState(StateId s, StateId p, const FromArc *) { + if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back(); + } + + void FinishVisit() {} + + private: + void OutputPath() { + if (ofst_->Start() == kNoStateId) { + const auto start = ofst_->AddState(); + ofst_->SetStart(start); + } + auto src = ofst_->Start(); + for (size_t i = 0; i < path_.size(); ++i) { + const auto dest = ofst_->AddState(); + const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); + ofst_->AddArc(src, arc); + src = dest; + } + ofst_->SetFinal(src, Weight::One()); + } + + const Fst *ifst_; + MutableFst *ofst_; + std::vector path_; + + RandGenVisitor(const RandGenVisitor &) = delete; + RandGenVisitor &operator=(const RandGenVisitor &) = delete; +}; + +} // namespace internal + +// Randomly generate paths through an FST; details controlled by +// RandGenOptions. +template +void RandGen(const Fst &ifst, MutableFst *ofst, + const RandGenOptions &opts) { + using Sampler = ArcSampler; + auto *sampler = new Sampler(ifst, opts.selector, opts.max_length); + RandGenFstOptions fopts(CacheOptions(true, 0), sampler, opts.npath, + opts.weighted, opts.remove_total_weight); + RandGenFst rfst(ifst, fopts); + if (opts.weighted) { + *ofst = rfst; + } else { + internal::RandGenVisitor rand_visitor(ofst); + DfsVisit(rfst, &rand_visitor); + } +} + +// Randomly generate a path through an FST with the uniform distribution +// over the transitions. +template +void RandGen(const Fst &ifst, MutableFst *ofst) { + const UniformArcSelector uniform_selector; + RandGenOptions> opts(uniform_selector); + RandGen(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_RANDGEN_H_ diff --git a/projects/llm_framework/include/fst/rational.h b/projects/llm_framework/include/fst/rational.h new file mode 100644 index 00000000..184ebf3f --- /dev/null +++ b/projects/llm_framework/include/fst/rational.h @@ -0,0 +1,307 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation and base interface for delayed unions, concatenations, +// and closures. + +#ifndef FST_RATIONAL_H_ +#define FST_RATIONAL_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +using RationalFstOptions = CacheOptions; + +// This specifies whether to add the empty string. +enum ClosureType { + CLOSURE_STAR = 0, // Add the empty string. + CLOSURE_PLUS = 1 // Don't add the empty string. +}; + +template +class RationalFst; + +template +void Union(RationalFst *fst1, const Fst &fst2); + +template +void Concat(RationalFst *fst1, const Fst &fst2); + +template +void Concat(const Fst &fst1, RationalFst *fst2); + +template +void Closure(RationalFst *fst, ClosureType closure_type); + +namespace internal { + +// Implementation class for delayed unions, concatenations and closures. +template +class RationalFstImpl : public FstImpl { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + explicit RationalFstImpl(const RationalFstOptions &opts) + : nonterminals_(0), replace_options_(opts, 0) { + SetType("rational"); + fst_tuples_.emplace_back(0, nullptr); + } + + RationalFstImpl(const RationalFstImpl &impl) + : rfst_(impl.rfst_), + nonterminals_(impl.nonterminals_), + replace_(impl.replace_ ? impl.replace_->Copy(true) : nullptr), + replace_options_(impl.replace_options_) { + SetType("rational"); + fst_tuples_.reserve(impl.fst_tuples_.size()); + for (const auto &pair : impl.fst_tuples_) { + fst_tuples_.emplace_back(pair.first, + pair.second ? pair.second->Copy(true) : nullptr); + } + } + + ~RationalFstImpl() override { + for (auto &tuple : fst_tuples_) delete tuple.second; + } + + StateId Start() { return Replace()->Start(); } + + Weight Final(StateId s) { return Replace()->Final(s); } + + size_t NumArcs(StateId s) { return Replace()->NumArcs(s); } + + size_t NumInputEpsilons(StateId s) { return Replace()->NumInputEpsilons(s); } + + size_t NumOutputEpsilons(StateId s) { + return Replace()->NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && Replace()->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + // Implementation of UnionFst(fst1, fst2). + void InitUnion(const Fst &fst1, const Fst &fst2) { + replace_.reset(); + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(0, 0, -2, Weight::One(), 1); + fst_tuples_.emplace_back(-1, fst1.Copy()); + fst_tuples_.emplace_back(-2, fst2.Copy()); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ConcatFst(fst1, fst2). + void InitConcat(const Fst &fst1, const Fst &fst2) { + replace_.reset(); + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(2, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(1, 0, -2, Weight::One(), 2); + fst_tuples_.emplace_back(-1, fst1.Copy()); + fst_tuples_.emplace_back(-2, fst2.Copy()); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ClosureFst(fst, closure_type). + void InitClosure(const Fst &fst, ClosureType closure_type) { + replace_.reset(); + const auto props = fst.Properties(kFstProperties, false); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + if (closure_type == CLOSURE_STAR) { + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(0, Weight::One()); + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 0); + } else { + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(1, 0, 0, Weight::One(), 0); + } + rfst_.SetInputSymbols(fst.InputSymbols()); + rfst_.SetOutputSymbols(fst.OutputSymbols()); + fst_tuples_.emplace_back(-1, fst.Copy()); + nonterminals_ = 1; + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Implementation of Union(Fst &, RationalFst *). + void AddUnion(const Fst &fst) { + replace_.reset(); + const auto props1 = FstImpl::Properties(); + const auto props2 = fst.Properties(kFstProperties, false); + VectorFst afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1); + Union(&rfst_, afst); + fst_tuples_.emplace_back(-nonterminals_, fst.Copy()); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Concat(Fst &, RationalFst *). + void AddConcat(const Fst &fst, bool append) { + replace_.reset(); + const auto props1 = FstImpl::Properties(); + const auto props2 = fst.Properties(kFstProperties, false); + VectorFst afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1); + if (append) { + Concat(&rfst_, afst); + } else { + Concat(afst, &rfst_); + } + fst_tuples_.emplace_back(-nonterminals_, fst.Copy()); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Closure(RationalFst *, closure_type). + void AddClosure(ClosureType closure_type) { + replace_.reset(); + const auto props = FstImpl::Properties(); + Closure(&rfst_, closure_type); + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Returns the underlying ReplaceFst, preserving ownership of the underlying + // object. + ReplaceFst *Replace() const { + if (!replace_) { + fst_tuples_[0].second = rfst_.Copy(); + replace_.reset(new ReplaceFst(fst_tuples_, replace_options_)); + } + return replace_.get(); + } + + private: + // Rational topology machine, using negative non-terminals. + VectorFst rfst_; + // Number of nonterminals used. + Label nonterminals_; + // Contains the nonterminals and their corresponding FSTs. + mutable std::vector *>> fst_tuples_; + // Underlying ReplaceFst. + mutable std::unique_ptr> replace_; + const ReplaceFstOptions replace_options_; +}; + +} // namespace internal + +// Parent class for the delayed rational operations (union, concatenation, and +// closure). This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class RationalFst : public ImplToFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::RationalFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + friend void Union<>(RationalFst *fst1, const Fst &fst2); + friend void Concat<>(RationalFst *fst1, const Fst &fst2); + friend void Concat<>(const Fst &fst1, RationalFst *fst2); + friend void Closure<>(RationalFst *fst, ClosureType closure_type); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->Replace()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->Replace()->InitArcIterator(s, data); + } + + protected: + using ImplToFst::GetImpl; + + explicit RationalFst(const RationalFstOptions &opts = RationalFstOptions()) + : ImplToFst(std::make_shared(opts)) {} + + // See Fst<>::Copy() for doc. + RationalFst(const RationalFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + private: + RationalFst &operator=(const RationalFst &) = delete; +}; + +// Specialization for RationalFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const RationalFst &fst) + : StateIterator>(*(fst.GetImpl()->Replace())) {} +}; + +// Specialization for RationalFst. +template +class ArcIterator> : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RationalFst &fst, StateId s) + : ArcIterator>(*(fst.GetImpl()->Replace()), s) {} +}; + +} // namespace fst + +#endif // FST_RATIONAL_H_ diff --git a/projects/llm_framework/include/fst/register.h b/projects/llm_framework/include/fst/register.h new file mode 100644 index 00000000..2d1a6ea7 --- /dev/null +++ b/projects/llm_framework/include/fst/register.h @@ -0,0 +1,115 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for registering derived FST for generic reading. + +#ifndef FST_REGISTER_H_ +#define FST_REGISTER_H_ + +#include +#include + + +#include +#include +#include + + +#include +#include + +namespace fst { + +template +class Fst; + +struct FstReadOptions; + +// This class represents a single entry in a FstRegister +template +struct FstRegisterEntry { + using Reader = Fst *(*)(std::istream &istrm, const FstReadOptions &opts); + using Converter = Fst *(*)(const Fst &fst); + + Reader reader; + Converter converter; + + explicit FstRegisterEntry(Reader reader = nullptr, + Converter converter = nullptr) + : reader(reader), converter(converter) {} +}; + +// This class maintains the correspondence between a string describing +// an FST type, and its reader and converter. +template +class FstRegister + : public GenericRegister, FstRegister> { + public: + using Reader = typename FstRegisterEntry::Reader; + using Converter = typename FstRegisterEntry::Converter; + + const Reader GetReader(const string &type) const { + return this->GetEntry(type).reader; + } + + const Converter GetConverter(const string &type) const { + return this->GetEntry(type).converter; + } + + protected: + string ConvertKeyToSoFilename(const string &key) const override { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-fst.so"; + } +}; + +// This class registers an FST type for generic reading and creating. +// The type must have a default constructor and a copy constructor from +// Fst. +template +class FstRegisterer : public GenericRegisterer> { + public: + using Arc = typename FST::Arc; + using Entry = typename FstRegister::Entry; + using Reader = typename FstRegister::Reader; + + FstRegisterer() + : GenericRegisterer>(FST().Type(), + BuildEntry()) {} + + private: + static Fst *ReadGeneric( + std::istream &strm, const FstReadOptions &opts) { + static_assert(std::is_base_of, FST>::value, + "FST class does not inherit from Fst"); + return FST::Read(strm, opts); + } + + static Entry BuildEntry() { + return Entry(&ReadGeneric, &FstRegisterer::Convert); + } + + static Fst *Convert(const Fst &fst) { return new FST(fst); } +}; + +// Convenience macro to generate static FstRegisterer instance. +#define REGISTER_FST(FST, Arc) \ + static fst::FstRegisterer> FST##_##Arc##_registerer + +// Converts an FST to the specified type. +template +Fst *Convert(const Fst &fst, const string &fst_type) { + auto *reg = FstRegister::GetRegister(); + const auto converter = reg->GetConverter(fst_type); + if (!converter) { + FSTERROR() << "Fst::Convert: Unknown FST type " << fst_type << " (arc type " + << Arc::Type() << ")"; + return nullptr; + } + return converter(fst); +} + +} // namespace fst + +#endif // FST_REGISTER_H_ diff --git a/projects/llm_framework/include/fst/relabel.h b/projects/llm_framework/include/fst/relabel.h new file mode 100644 index 00000000..0979b077 --- /dev/null +++ b/projects/llm_framework/include/fst/relabel.h @@ -0,0 +1,472 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to relabel an FST (either on input or output). + +#ifndef FST_RELABEL_H_ +#define FST_RELABEL_H_ + +#include +#include +#include +#include + +#include + +#include +#include + + +#include + +namespace fst { + +// Relabels either the input labels or output labels. The old to +// new labels are specified using a vector of std::pair. +// Any label associations not specified are assumed to be identity +// mapping. The destination labels must be valid labels (e.g., not kNoLabel). +template +void Relabel( + MutableFst *fst, + const std::vector> + &ipairs, + const std::vector> + &opairs) { + using Label = typename Arc::Label; + const auto props = fst->Properties(kFstProperties, false); + // Constructs label-to-label maps. + const std::unordered_map input_map( + ipairs.begin(), ipairs.end()); + const std::unordered_map output_map( + opairs.begin(), opairs.end()); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + for (MutableArcIterator> aiter(fst, siter.Value()); + !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + // Relabels input. + auto it = input_map.find(arc.ilabel); + if (it != input_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Input symbol ID " << arc.ilabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.ilabel = it->second; + } + // Relabels output. + it = output_map.find(arc.olabel); + if (it != output_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Output symbol id " << arc.olabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.olabel = it->second; + } + aiter.SetValue(arc); + } + } + fst->SetProperties(RelabelProperties(props), kFstProperties); +} + +// Relabels either the input labels or output labels. The old to +// new labels are specified using pairs of old and new symbol tables. +// The tables must contain (at least) all labels on the appropriate side of the +// FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any +// missing symbol in new_i(o)symbols table. +template +void Relabel(MutableFst *fst, + const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, + const string &unknown_isymbol, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + const string &unknown_osymbol, bool attach_new_osymbols) { + using Label = typename Arc::Label; + // Constructs vectors of input-side label pairs. + std::vector> ipairs; + if (old_isymbols && new_isymbols) { + size_t num_missing_syms = 0; + Label unknown_ilabel = kNoLabel; + if (!unknown_isymbol.empty()) { + unknown_ilabel = new_isymbols->Find(unknown_isymbol); + if (unknown_ilabel == kNoLabel) { + VLOG(1) << "Input symbol '" << unknown_isymbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + + for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); + siter.Next()) { + const auto old_index = siter.Value(); + const auto symbol = siter.Symbol(); + auto new_index = new_isymbols->Find(siter.Symbol()); + if (new_index == kNoLabel) { + if (unknown_ilabel != kNoLabel) { + new_index = unknown_ilabel; + } else { + VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + ipairs.push_back(std::make_pair(old_index, new_index)); + } + if (num_missing_syms > 0) { + LOG(WARNING) << "Target symbol table missing: " << num_missing_syms + << " input symbols"; + } + if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols); + } + // Constructs vectors of output-side label pairs. + std::vector> opairs; + if (old_osymbols && new_osymbols) { + size_t num_missing_syms = 0; + Label unknown_olabel = kNoLabel; + if (!unknown_osymbol.empty()) { + unknown_olabel = new_osymbols->Find(unknown_osymbol); + if (unknown_olabel == kNoLabel) { + VLOG(1) << "Output symbol '" << unknown_osymbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + + for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); + siter.Next()) { + const auto old_index = siter.Value(); + const auto symbol = siter.Symbol(); + auto new_index = new_osymbols->Find(siter.Symbol()); + if (new_index == kNoLabel) { + if (unknown_olabel != kNoLabel) { + new_index = unknown_olabel; + } else { + VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + opairs.push_back(std::make_pair(old_index, new_index)); + } + if (num_missing_syms > 0) { + LOG(WARNING) << "Target symbol table missing: " << num_missing_syms + << " output symbols"; + } + if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols); + } + // Calls relabel using vector of relabel pairs. + Relabel(fst, ipairs, opairs); +} + +// Same as previous but no special allowance for unknown symbols. Kept +// for backward compat. +template +void Relabel(MutableFst *fst, const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + bool attach_new_osymbols) { + Relabel(fst, + old_isymbols, new_isymbols, "" /* no unknown isymbol */, + attach_new_isymbols, + old_osymbols, new_osymbols, "" /* no unknown ioymbol */, + attach_new_osymbols); +} + + +// Relabels either the input labels or output labels. The old to +// new labels are specified using symbol tables. Any label associations not +// specified are assumed to be identity mapping. +template +void Relabel(MutableFst *fst, const SymbolTable *new_isymbols, + const SymbolTable *new_osymbols) { + Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(), + new_osymbols, true); +} + +using RelabelFstOptions = CacheOptions; + +template +class RelabelFst; + +namespace internal { + +// Relabels an FST from one symbol set to another. Relabeling can either be on +// input or output space. RelabelFst implements a delayed version of the +// relabel. Arcs are relabeled on the fly and not cached; i.e., each request is +// recomputed. +template +class RelabelFstImpl : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::PushArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + RelabelFstImpl(const Fst &fst, + const std::vector> &ipairs, + const std::vector> &opairs, + const RelabelFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + input_map_(ipairs.begin(), ipairs.end()), + output_map_(opairs.begin(), opairs.end()), + relabel_input_(!ipairs.empty()), + relabel_output_(!opairs.empty()) { + SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); + SetType("relabel"); + } + + RelabelFstImpl(const Fst &fst, + const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, + const SymbolTable *old_osymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + relabel_input_(false), + relabel_output_(false) { + SetType("relabel"); + SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); + SetInputSymbols(old_isymbols); + SetOutputSymbols(old_osymbols); + if (old_isymbols && new_isymbols && + old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { + for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); + siter.Next()) { + input_map_[siter.Value()] = new_isymbols->Find(siter.Symbol()); + } + SetInputSymbols(new_isymbols); + relabel_input_ = true; + } + if (old_osymbols && new_osymbols && + old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { + for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); + siter.Next()) { + output_map_[siter.Value()] = new_osymbols->Find(siter.Symbol()); + } + SetOutputSymbols(new_osymbols); + relabel_output_ = true; + } + } + + RelabelFstImpl(const RelabelFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + input_map_(impl.input_map_), + output_map_(impl.output_map_), + relabel_input_(impl.relabel_input_), + relabel_output_(impl.relabel_output_) { + SetType("relabel"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(fst_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) SetFinal(s, fst_->Final(s)); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && fst_->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + for (ArcIterator> aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + if (relabel_input_) { + auto it = input_map_.find(arc.ilabel); + if (it != input_map_.end()) arc.ilabel = it->second; + } + if (relabel_output_) { + auto it = output_map_.find(arc.olabel); + if (it != output_map_.end()) { + arc.olabel = it->second; + } + } + PushArc(s, std::move(arc)); + } + SetArcs(s); + } + + private: + std::unique_ptr> fst_; + + std::unordered_map input_map_; + std::unordered_map output_map_; + bool relabel_input_; + bool relabel_output_; +}; + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class RelabelFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::RelabelFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + RelabelFst(const Fst &fst, + const std::vector> &ipairs, + const std::vector> &opairs, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst(std::make_shared(fst, ipairs, opairs, opts)) {} + + RelabelFst(const Fst &fst, const SymbolTable *new_isymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst( + std::make_shared(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, opts)) {} + + RelabelFst(const Fst &fst, const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst(std::make_shared(fst, old_isymbols, new_isymbols, + old_osymbols, new_osymbols, + opts)) {} + + // See Fst<>::Copy() for doc. + RelabelFst(const RelabelFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc. + RelabelFst *Copy(bool safe = false) const override { + return new RelabelFst(*this, safe); + } + + void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + return GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RelabelFst &operator=(const RelabelFst &) = delete; +}; + +// Specialization for RelabelFst. +template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const RelabelFst &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} + + bool Done() const final { return siter_.Done(); } + + StateId Value() const final { return s_; } + + void Next() final { + if (!siter_.Done()) { + ++s_; + siter_.Next(); + } + } + + void Reset() final { + s_ = 0; + siter_.Reset(); + } + + private: + const internal::RelabelFstImpl* impl_; + StateIterator> siter_; + StateId s_; + + StateIterator(const StateIterator &) = delete; + StateIterator &operator=(const StateIterator &) = delete; +}; + +// Specialization for RelabelFst. +template +class ArcIterator> : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RelabelFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RelabelFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful alias when using StdArc. +using StdRelabelFst = RelabelFst; + +} // namespace fst + +#endif // FST_RELABEL_H_ diff --git a/projects/llm_framework/include/fst/replace-util.h b/projects/llm_framework/include/fst/replace-util.h new file mode 100644 index 00000000..42c69824 --- /dev/null +++ b/projects/llm_framework/include/fst/replace-util.h @@ -0,0 +1,629 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utility classes for the recursive replacement of FSTs (RTNs). + +#ifndef FST_REPLACE_UTIL_H_ +#define FST_REPLACE_UTIL_H_ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + + +namespace fst { + +// This specifies what labels to output on the call or return arc. Note that +// REPLACE_LABEL_INPUT and REPLACE_LABEL_OUTPUT will produce transducers when +// applied to acceptors. +enum ReplaceLabelType { + // Epsilon labels on both input and output. + REPLACE_LABEL_NEITHER = 1, + // Non-epsilon labels on input and epsilon on output. + REPLACE_LABEL_INPUT = 2, + // Epsilon on input and non-epsilon on output. + REPLACE_LABEL_OUTPUT = 3, + // Non-epsilon labels on both input and output. + REPLACE_LABEL_BOTH = 4 +}; + +// By default ReplaceUtil will copy the input label of the replace arc. +// The call_label_type and return_label_type options specify how to manage +// the labels of the call arc and the return arc of the replace FST +struct ReplaceUtilOptions { + int64 root; // Root rule for expansion. + ReplaceLabelType call_label_type; // How to label call arc. + ReplaceLabelType return_label_type; // How to label return arc. + int64 return_label; // Label to put on return arc. + + explicit ReplaceUtilOptions( + int64 root = kNoLabel, + ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT, + ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER, + int64 return_label = 0) + : root(root), + call_label_type(call_label_type), + return_label_type(return_label_type), + return_label(return_label) {} + + // For backwards compatibility. + ReplaceUtilOptions(int64 root, bool epsilon_replace_arc) + : ReplaceUtilOptions(root, + epsilon_replace_arc ? REPLACE_LABEL_NEITHER + : REPLACE_LABEL_INPUT) {} +}; + +// Every non-terminal on a path appears as the first label on that path in every +// FST associated with a given SCC of the replace dependency graph. This would +// be true if the SCC were formed from left-linear grammar rules. +constexpr uint8 kReplaceSCCLeftLinear = 0x01; +// Every non-terminal on a path appears as the final label on that path in every +// FST associated with a given SCC of the replace dependency graph. This would +// be true if the SCC were formed from right-linear grammar rules. +constexpr uint8 kReplaceSCCRightLinear = 0x02; +// The SCC in the replace dependency graph has more than one state or a +// self-loop. +constexpr uint8 kReplaceSCCNonTrivial = 0x04; + +// Defined in replace.h. +template +void Replace( + const std::vector *>> &, + MutableFst *, const ReplaceUtilOptions &); + +// Utility class for the recursive replacement of FSTs (RTNs). The user provides +// a set of label/FST pairs at construction. These are used by methods for +// testing cyclic dependencies and connectedness and doing RTN connection and +// specific FST replacement by label or for various optimization properties. The +// modified results can be obtained with the GetFstPairs() or +// GetMutableFstPairs() methods. +template +class ReplaceUtil { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstPair = std::pair *>; + using MutableFstPair = std::pair *>; + using NonTerminalHash = std::unordered_map; + + // Constructs from mutable FSTs; FST ownership is given to ReplaceUtil. + ReplaceUtil(const std::vector &fst_pairs, + const ReplaceUtilOptions &opts); + + // Constructs from FSTs; FST ownership is retained by caller. + ReplaceUtil(const std::vector &fst_pairs, + const ReplaceUtilOptions &opts); + + // Constructs from ReplaceFst internals; FST ownership is retained by caller. + ReplaceUtil(const std::vector>> &fst_array, + const NonTerminalHash &nonterminal_hash, + const ReplaceUtilOptions &opts); + + ~ReplaceUtil() { + for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i]; + } + + // True if the non-terminal dependencies are cyclic. Cyclic dependencies will + // result in an unexpandable FST. + bool CyclicDependencies() const { + GetDependencies(false); + return depprops_ & kCyclic; + } + + // Returns the strongly-connected component ID in the dependency graph of the + // replace FSTS. + StateId SCC(Label label) const { + GetDependencies(false); + const auto it = nonterminal_hash_.find(label); + if (it == nonterminal_hash_.end()) return kNoStateId; + return depscc_[it->second]; + } + + // Returns properties for the strongly-connected component in the dependency + // graph of the replace FSTs. If the SCC is kReplaceSCCLeftLinear or + // kReplaceSCCRightLinear, that SCC can be represented as finite-state despite + // any cyclic dependencies, but not by the usual replacement operation (see + // fst/extensions/pdt/replace.h). + uint8 SCCProperties(StateId scc_id) { + GetSCCProperties(); + return depsccprops_[scc_id]; + } + + // Returns true if no useless FSTs, states or transitions are present in the + // RTN. + bool Connected() const { + GetDependencies(false); + uint64 props = kAccessible | kCoAccessible; + for (Label i = 0; i < fst_array_.size(); ++i) { + if (!fst_array_[i]) continue; + if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) { + return false; + } + } + return true; + } + + // Removes useless FSTs, states and transitions from the RTN. + void Connect(); + + // Replaces FSTs specified by labels, unless there are cyclic dependencies. + void ReplaceLabels(const std::vector */> +class ReplaceFst + : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StateTable = T; + using Store = CacheStore; + using State = typename CacheStore::State; + using Impl = internal::ReplaceFstImpl; + using CacheImpl = internal::CacheBaseImpl; + + using ImplToFst::Properties; + + friend class ArcIterator>; + friend class StateIterator>; + friend class ReplaceFstMatcher; + + ReplaceFst(const std::vector *>> &fst_array, + Label root) + : ImplToFst(std::make_shared( + fst_array, ReplaceFstOptions(root))) {} + + ReplaceFst(const std::vector *>> &fst_array, + const ReplaceFstOptions &opts) + : ImplToFst(std::make_shared(fst_array, opts)) {} + + // See Fst<>::Copy() for doc. + ReplaceFst(const ReplaceFst &fst, + bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc. + ReplaceFst *Copy( + bool safe = false) const override { + return new ReplaceFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) && + ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) || + (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) { + return new ReplaceFstMatcher + (this, match_type); + } else { + VLOG(2) << "Not using replace matcher"; + return nullptr; + } + } + + bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); } + + const StateTable &GetStateTable() const { + return *GetImpl()->GetStateTable(); + } + + const Fst &GetFst(Label nonterminal) const { + return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal)); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + ReplaceFst &operator=(const ReplaceFst &) = delete; +}; + +// Specialization for ReplaceFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const ReplaceFst &fst) + : CacheStateIterator>( + fst, fst.GetMutableImpl()) {} +}; + +// Specialization for ReplaceFst, implementing optional caching. It is be used +// as follows: +// +// ReplaceFst replace; +// ArcIterator> aiter(replace, s); +// // Note: ArcIterator< Fst> is always a caching arc iterator. +// aiter.SetFlags(kArcNoCache, kArcNoCache); +// // Uses the arc iterator, no arc will be cached, no state will be expanded. +// // Arc flags can be used to decide which component of the arc need to be +// computed. +// aiter.SetFlags(kArcILabelValue, kArcValueFlags); +// // Wants the ilabel for this arc. +// aiter.Value(); // Does not compute the destination state. +// aiter.Next(); +// aiter.SetFlags(kArcNextStateValue, kArcNextStateValue); +// // Wants the ilabel and next state for this arc. +// aiter.Value(); // Does compute the destination state and inserts it +// // in the replace state table. +// // No additional arcs have been cached at this point. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + using StateTuple = typename StateTable::StateTuple; + + ArcIterator(const ReplaceFst &fst, StateId s) + : fst_(fst), + s_(s), + pos_(0), + offset_(0), + flags_(kArcValueFlags), + arcs_(nullptr), + data_flags_(0), + final_flags_(0) { + cache_data_.ref_count = nullptr; + local_data_.ref_count = nullptr; + // If FST does not support optional caching, forces caching. + if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) && + !(fst_.GetImpl()->HasArcs(s_))) { + fst_.GetMutableImpl()->Expand(s_); + } + // If state is already cached, use cached arcs array. + if (fst_.GetImpl()->HasArcs(s_)) { + (fst_.GetImpl()) + ->internal::template CacheBaseImpl< + typename CacheStore::State, + CacheStore>::InitArcIterator(s_, &cache_data_); + num_arcs_ = cache_data_.narcs; + arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + } else { // Otherwise delay decision until Value() is called. + tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_); + if (tuple_.fst_state == kNoStateId) { + num_arcs_ = 0; + } else { + // The decision to cache or not to cache has been defered until Value() + // or + // SetFlags() is called. However, the arc iterator is set up now to be + // ready for non-caching in order to keep the Value() method simple and + // efficient. + const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id); + rfst->InitArcIterator(tuple_.fst_state, &local_data_); + // arcs_ is a pointer to the arcs in the underlying machine. + arcs_ = local_data_.arcs; + // Computes the final arc (but not its destination state) if a final arc + // is required. + bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc( + tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue); + // Sets the arc value flags that hold for final_arc_. + final_flags_ = kArcValueFlags & ~kArcNextStateValue; + // Computes the number of arcs. + num_arcs_ = local_data_.narcs; + if (has_final_arc) ++num_arcs_; + // Sets the offset between the underlying arc positions and the + // positions + // in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + // Defers the decision to cache or not until Value() or SetFlags() is + // called. + data_flags_ = 0; + } + } + } + + ~ArcIterator() { + if (cache_data_.ref_count) --(*cache_data_.ref_count); + if (local_data_.ref_count) --(*local_data_.ref_count); + } + + void ExpandAndCache() const { + // TODO(allauzen): revisit this. + // fst_.GetImpl()->Expand(s_, tuple_, local_data_); + // (fst_.GetImpl())->CacheImpl*>::InitArcIterator(s_, + // &cache_data_); + // + fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state. + arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + offset_ = 0; // No offset. + } + + void Init() { + if (flags_ & kArcNoCache) { // If caching is disabled + // arcs_ is a pointer to the arcs in the underlying machine. + arcs_ = local_data_.arcs; + // Sets the arcs value flags that hold for arcs_. + data_flags_ = kArcWeightValue; + if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) { + data_flags_ |= kArcILabelValue; + } + // Sets the offset between the underlying arc positions and the positions + // in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + } else { + ExpandAndCache(); + } + } + + bool Done() const { return pos_ >= num_arcs_; } + + const Arc &Value() const { + // If data_flags_ is 0, non-caching was not requested. + if (!data_flags_) { + // TODO(allauzen): Revisit this. + if (flags_ & kArcNoCache) { + // Should never happen. + FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags"; + } + ExpandAndCache(); + } + if (pos_ - offset_ >= 0) { // The requested arc is not the final arc. + const auto &arc = arcs_[pos_ - offset_]; + if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) { + // If the value flags match the recquired value flags then returns the + // arc. + return arc; + } else { + // Otherwise, compute the corresponding arc on-the-fly. + fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_, + flags_ & kArcValueFlags); + return arc_; + } + } else { // The requested arc is the final arc. + if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) { + // If the arc value flags that hold for the final arc do not match the + // requested value flags, then + // final_arc_ needs to be updated. + fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_, + flags_ & kArcValueFlags); + final_flags_ = flags_ & kArcValueFlags; + } + return final_arc_; + } + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 flags, uint32 mask) { + // Updates the flags taking into account what flags are supported + // by the FST. + flags_ &= ~mask; + flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags()); + // If non-caching is not requested (and caching has not already been + // performed), then flush data_flags_ to request caching during the next + // call to Value(). + if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) { + if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0; + } + // If data_flags_ has been flushed but non-caching is requested before + // calling Value(), then set up the iterator for non-caching. + if ((flags & kArcNoCache) && (!data_flags_)) Init(); + } + + private: + const ReplaceFst &fst_; // Reference to the FST. + StateId s_; // State in the FST. + mutable StateTuple tuple_; // Tuple corresponding to state_. + + ssize_t pos_; // Current position. + mutable ssize_t offset_; // Offset between position in iterator and in arcs_. + ssize_t num_arcs_; // Number of arcs at state_. + uint32 flags_; // Behavorial flags for the arc iterator + mutable Arc arc_; // Memory to temporarily store computed arcs. + + mutable ArcIteratorData cache_data_; // Arc iterator data in cache. + mutable ArcIteratorData local_data_; // Arc iterator data in local FST. + + mutable const Arc *arcs_; // Array of arcs. + mutable uint32 data_flags_; // Arc value flags valid for data in arcs_. + mutable Arc final_arc_; // Final arc (when required). + mutable uint32 final_flags_; // Arc value flags valid for final_arc_. + + ArcIterator(const ArcIterator &) = delete; + ArcIterator &operator=(const ArcIterator &) = delete; +}; + +template +class ReplaceFstMatcher : public MatcherBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FST = ReplaceFst; + using LocalMatcher = MultiEpsMatcher>>; + + using StateTuple = typename StateTable::StateTuple; + + // This makes a copy of the FST. + ReplaceFstMatcher(const ReplaceFst &fst, + MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(match_type), + current_loop_(false), + final_arc_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // This doesn't copy the FST. + ReplaceFstMatcher(const ReplaceFst *fst, + MatchType match_type) + : fst_(*fst), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(match_type), + current_loop_(false), + final_arc_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // This makes a copy of the FST. + ReplaceFstMatcher( + const ReplaceFstMatcher &matcher, + bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(matcher.match_type_), + current_loop_(false), + final_arc_(false), + loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // Creates a local matcher for each component FST in the RTN. LocalMatcher is + // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each + // non-terminal arc, since these non-terminal + // turn into epsilons on recursion. + void InitMatchers() { + const auto &fst_array = impl_->fst_array_; + matcher_.resize(fst_array.size()); + for (Label i = 0; i < fst_array.size(); ++i) { + if (fst_array[i]) { + matcher_[i].reset( + new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList)); + auto it = impl_->nonterminal_set_.begin(); + for (; it != impl_->nonterminal_set_.end(); ++it) { + matcher_[i]->AddMultiEpsLabel(*it); + } + } + } + } + + ReplaceFstMatcher *Copy( + bool safe = false) const override { + return new ReplaceFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if (match_type_ == MATCH_NONE) return match_type_; + const auto true_prop = + match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; + const auto false_prop = + match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted; + const auto props = fst_.Properties(true_prop | false_prop, test); + if (props & true_prop) { + return match_type_; + } else if (props & false_prop) { + return MATCH_NONE; + } else { + return MATCH_UNKNOWN; + } + } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { return props; } + + // Sets the state from which our matching happens. + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + tuple_ = impl_->GetStateTable()->Tuple(s_); + if (tuple_.fst_state == kNoStateId) { + done_ = true; + return; + } + // Gets current matcher, used for non-epsilon matching. + current_matcher_ = matcher_[tuple_.fst_id].get(); + current_matcher_->SetState(tuple_.fst_state); + loop_.nextstate = s_; + final_arc_ = false; + } + + // Searches for label from previous set state. If label == 0, first + // hallucinate an epsilon loop; otherwise use the underlying matcher to + // search for the label or epsilons. Note since the ReplaceFst recursion + // on non-terminal arcs causes epsilon transitions to be created we use + // MultiEpsilonMatcher to search for possible matches of non-terminals. If the + // component FST + // reaches a final state we also need to add the exiting final arc. + bool Find(Label label) final { + bool found = false; + label_ = label; + if (label_ == 0 || label_ == kNoLabel) { + // Computes loop directly, avoiding Replace::ComputeArc. + if (label_ == 0) { + current_loop_ = true; + found = true; + } + // Searches for matching multi-epsilons. + final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr); + found = current_matcher_->Find(kNoLabel) || final_arc_ || found; + } else { + // Searches on a sub machine directly using sub machine matcher. + found = current_matcher_->Find(label_); + } + return found; + } + + bool Done() const final { + return !current_loop_ && !final_arc_ && current_matcher_->Done(); + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + if (final_arc_) { + impl_->ComputeFinalArc(tuple_, &arc_); + return arc_; + } + const auto &component_arc = current_matcher_->Value(); + impl_->ComputeArc(tuple_, component_arc, &arc_); + return arc_; + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + return; + } + if (final_arc_) { + final_arc_ = false; + return; + } + current_matcher_->Next(); + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + std::unique_ptr> owned_fst_; + const ReplaceFst &fst_; + internal::ReplaceFstImpl *impl_; + LocalMatcher *current_matcher_; + std::vector> matcher_; + StateId s_; // Current state. + Label label_; // Current label. + MatchType match_type_; // Supplied by caller. + mutable bool done_; + mutable bool current_loop_; // Current arc is the implicit loop. + mutable bool final_arc_; // Current arc for exiting recursion. + mutable StateTuple tuple_; // Tuple corresponding to state_. + mutable Arc arc_; + Arc loop_; + + ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete; +}; + +template +inline void ReplaceFst::InitStateIterator( + StateIteratorData *data) const { + data->base = + new StateIterator>(*this); +} + +using StdReplaceFst = ReplaceFst; + +// Recursively replaces arcs in the root FSTs with other FSTs. +// This version writes the result of replacement to an output MutableFst. +// +// Replace supports replacement of arcs in one Fst with another FST. This +// replacement is recursive. Replace takes an array of FST(s). One FST +// represents the root (or topology) machine. The root FST refers to other FSTs +// by recursively replacing arcs labeled as non-terminals with the matching +// non-terminal FST. Currently Replace uses the output symbols of the arcs to +// determine whether the arc is a non-terminal arc or not. A non-terminal can be +// any label that is not a non-zero terminal label in the output alphabet. +// +// Note that input argument is a vector of pairs. These correspond to the tuple +// of non-terminal Label and corresponding FST. +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, + ReplaceFstOptions opts = ReplaceFstOptions()) { + opts.gc = true; + opts.gc_limit = 0; // Caches only the last state for fastest copy. + *ofst = ReplaceFst(ifst_array, opts); +} + +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, const ReplaceUtilOptions &opts) { + Replace(ifst_array, ofst, ReplaceFstOptions(opts)); +} + +// For backwards compatibility. +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, typename Arc::Label root, + bool epsilon_on_replace) { + Replace(ifst_array, ofst, ReplaceFstOptions(root, epsilon_on_replace)); +} + +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, typename Arc::Label root) { + Replace(ifst_array, ofst, ReplaceFstOptions(root)); +} + +} // namespace fst + +#endif // FST_REPLACE_H_ diff --git a/projects/llm_framework/include/fst/reverse.h b/projects/llm_framework/include/fst/reverse.h new file mode 100644 index 00000000..7c7c89db --- /dev/null +++ b/projects/llm_framework/include/fst/reverse.h @@ -0,0 +1,116 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to sort arcs in an FST. + +#ifndef FST_REVERSE_H_ +#define FST_REVERSE_H_ + +#include +#include + +#include + + +namespace fst { + +// Reverses an FST. The reversed result is written to an output mutable FST. +// If A transduces string x to y with weight a, then the reverse of A +// transduces the reverse of x to the reverse of y with weight a.Reverse(). +// +// Typically, a = a.Reverse() and an arc is its own reverse (e.g., for +// TropicalWeight or LogWeight). In general, e.g., when the weights only form a +// left or right semiring, the output arc type must match the input arc type +// except having the reversed Weight type. +// +// When require_superinitial is false, a superinitial state is not created in +// the reversed FST iff the input FST has exactly one final state (which becomes +// the initial state of the reversed FST) with a final weight of semiring One, +// or if it does not belong to any cycle. When require_superinitial is true, a +// superinitial state is always created. +template +void Reverse(const Fst &ifst, MutableFst *ofst, + bool require_superinitial = true) { + using StateId = typename FromArc::StateId; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kExpanded, false)) { + ofst->ReserveStates(CountStates(ifst) + 1); + } + StateId istart = ifst.Start(); + StateId ostart = kNoStateId; + StateId offset = 0; + uint64 dfs_iprops = 0; + uint64 dfs_oprops = 0; + if (!require_superinitial) { + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (ifst.Final(s) == FromWeight::Zero()) continue; + if (ostart != kNoStateId) { + ostart = kNoStateId; + break; + } else { + ostart = s; + } + } + if (ostart != kNoStateId && ifst.Final(ostart) != FromWeight::One()) { + std::vector scc; + SccVisitor scc_visitor(&scc, nullptr, nullptr, &dfs_iprops); + DfsVisit(ifst, &scc_visitor); + if (count(scc.begin(), scc.end(), scc[ostart]) > 1) { + ostart = kNoStateId; + } else { + for (ArcIterator> aiter(ifst, ostart); !aiter.Done(); + aiter.Next()) { + if (aiter.Value().nextstate == ostart) { + ostart = kNoStateId; + break; + } + } + } + if (ostart != kNoStateId) dfs_oprops = kInitialAcyclic; + } + } + if (ostart == kNoStateId) { // Super-initial requested or needed. + ostart = ofst->AddState(); + offset = 1; + } + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto is = siter.Value(); + const auto os = is + offset; + while (ofst->NumStates() <= os) ofst->AddState(); + if (is == istart) ofst->SetFinal(os, ToWeight::One()); + const auto weight = ifst.Final(is); + if ((weight != FromWeight::Zero()) && (offset == 1)) { + const ToArc oarc(0, 0, weight.Reverse(), os); + ofst->AddArc(0, oarc); + } + for (ArcIterator> aiter(ifst, is); !aiter.Done(); + aiter.Next()) { + const auto &iarc = aiter.Value(); + const auto nos = iarc.nextstate + offset; + auto weight = iarc.weight.Reverse(); + if (!offset && (nos == ostart)) { + weight = Times(ifst.Final(ostart).Reverse(), weight); + } + const ToArc oarc(iarc.ilabel, iarc.olabel, weight, os); + while (ofst->NumStates() <= nos) ofst->AddState(); + ofst->AddArc(nos, oarc); + } + } + ofst->SetStart(ostart); + if (offset == 0 && ostart == istart) { + ofst->SetFinal(ostart, ifst.Final(ostart).Reverse()); + } + const auto iprops = ifst.Properties(kCopyProperties, false) | dfs_iprops; + const auto oprops = ofst->Properties(kFstProperties, false) | dfs_oprops; + ofst->SetProperties(ReverseProperties(iprops, offset == 1) | oprops, + kFstProperties); +} + +} // namespace fst + +#endif // FST_REVERSE_H_ diff --git a/projects/llm_framework/include/fst/reweight.h b/projects/llm_framework/include/fst/reweight.h new file mode 100644 index 00000000..64e68cb7 --- /dev/null +++ b/projects/llm_framework/include/fst/reweight.h @@ -0,0 +1,127 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to reweight an FST. + +#ifndef FST_REWEIGHT_H_ +#define FST_REWEIGHT_H_ + +#include +#include + +#include + + +namespace fst { + +enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL }; + +// Reweights an FST according to a vector of potentials in a given direction. +// The weight must be left distributive when reweighting towards the initial +// state and right distributive when reweighting towards the final states. +// +// An arc of weight w, with an origin state of potential p and destination state +// of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting +// torwards the initial state, and by (p \otimes w) \otimes q^-1 when +// reweighting towards the final states. +template +void Reweight(MutableFst *fst, + const std::vector &potential, + ReweightType type) { + using Weight = typename Arc::Weight; + if (fst->NumStates() == 0) return; + // TODO(kbg): Make this a compile-time static_assert once we have a pleasant + // way to "deregister" this operation for non-distributive semirings so an + // informative error message is produced. + if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) { + FSTERROR() << "Reweight: Reweighting to the final states requires " + << "Weight to be right distributive: " << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + // TODO(kbg): Make this a compile-time static_assert once we have a pleasant + // way to "deregister" this operation for non-distributive semirings so an + // informative error message is produced. + if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "Reweight: Reweighting to the initial state requires " + << "Weight to be left distributive: " << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + StateIterator> siter(*fst); + for (; !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s == potential.size()) break; + const auto &weight = potential[s]; + if (weight != Weight::Zero()) { + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + if (arc.nextstate >= potential.size()) continue; + const auto &nextweight = potential[arc.nextstate]; + if (nextweight == Weight::Zero()) continue; + if (type == REWEIGHT_TO_INITIAL) { + arc.weight = + Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT); + } + if (type == REWEIGHT_TO_FINAL) { + arc.weight = + Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT); + } + aiter.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) { + fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT)); + } + } + if (type == REWEIGHT_TO_FINAL) { + fst->SetFinal(s, Times(weight, fst->Final(s))); + } + } + // This handles elements past the end of the potentials array. + for (; !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (type == REWEIGHT_TO_FINAL) { + fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s))); + } + } + const auto startweight = fst->Start() < potential.size() + ? potential[fst->Start()] + : Weight::Zero(); + if ((startweight != Weight::One()) && (startweight != Weight::Zero())) { + if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { + const auto s = fst->Start(); + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + if (type == REWEIGHT_TO_INITIAL) { + arc.weight = Times(startweight, arc.weight); + } else { + arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT), + arc.weight); + } + aiter.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) { + fst->SetFinal(s, Times(startweight, fst->Final(s))); + } else { + fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT), + fst->Final(s))); + } + } else { + const auto s = fst->AddState(); + const auto weight = + (type == REWEIGHT_TO_INITIAL) + ? startweight + : Divide(Weight::One(), startweight, DIVIDE_RIGHT); + fst->AddArc(s, Arc(0, 0, weight, fst->Start())); + fst->SetStart(s); + } + } + fst->SetProperties(ReweightProperties(fst->Properties(kFstProperties, false)), + kFstProperties); +} + +} // namespace fst + +#endif // FST_REWEIGHT_H_ diff --git a/projects/llm_framework/include/fst/rmepsilon.h b/projects/llm_framework/include/fst/rmepsilon.h new file mode 100644 index 00000000..5135bf2d --- /dev/null +++ b/projects/llm_framework/include/fst/rmepsilon.h @@ -0,0 +1,548 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes that implemement epsilon-removal. + +#ifndef FST_RMEPSILON_H_ +#define FST_RMEPSILON_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +class RmEpsilonOptions + : public ShortestDistanceOptions> { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + bool connect; // Connect output + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + + explicit RmEpsilonOptions(Queue *queue, float delta = kShortestDelta, + bool connect = true, + Weight weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId) + : ShortestDistanceOptions>( + queue, EpsilonArcFilter(), kNoStateId, delta), + connect(connect), + weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Computation state of the epsilon-removal algorithm. +template +class RmEpsilonState { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + RmEpsilonState(const Fst &fst, std::vector *distance, + const RmEpsilonOptions &opts) + : fst_(fst), + distance_(distance), + sd_state_(fst_, distance, opts, true), + expand_id_(0) {} + + void Expand(StateId s); + + std::vector &Arcs() { return arcs_; } + + const Weight &Final() const { return final_; } + + bool Error() const { return sd_state_.Error(); } + + private: + struct Element { + Label ilabel; + Label olabel; + StateId nextstate; + + Element() {} + + Element(Label ilabel, Label olabel, StateId nexstate) + : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {} + }; + + struct ElementHash { + public: + size_t operator()(const Element &element) const { + static constexpr size_t prime0 = 7853; + static constexpr size_t prime1 = 7867; + return static_cast(element.nextstate) + + static_cast(element.ilabel) * prime0 + + static_cast(element.olabel) * prime1; + } + }; + + class ElementEqual { + public: + bool operator()(const Element &e1, const Element &e2) const { + return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) && + (e1.nextstate == e2.nextstate); + } + }; + + using ElementMap = std::unordered_map, + ElementHash, ElementEqual>; + + const Fst &fst_; + // Distance from state being expanded in epsilon-closure. + std::vector *distance_; + // Shortest distance algorithm computation state. + internal::ShortestDistanceState> sd_state_; + // Maps an element to a pair corresponding to a position in the arcs vector + // of the state being expanded. The element corresopnds to the position in + // the arcs_ vector if p.first is equal to the state being expanded. + ElementMap element_map_; + EpsilonArcFilter eps_filter_; + std::stack eps_queue_; // Queue used to visit the epsilon-closure. + std::vector visited_; // True if the state has been visited. + std::forward_list visited_states_; // List of visited states. + std::vector arcs_; // Arcs of state being expanded. + Weight final_; // Final weight of state being expanded. + StateId expand_id_; // Unique ID for each call to Expand + + RmEpsilonState(const RmEpsilonState &) = delete; + RmEpsilonState &operator=(const RmEpsilonState &) = delete; +}; + +template +void RmEpsilonState::Expand(typename Arc::StateId source) { + final_ = Weight::Zero(); + arcs_.clear(); + sd_state_.ShortestDistance(source); + if (sd_state_.Error()) return; + eps_queue_.push(source); + while (!eps_queue_.empty()) { + const auto state = eps_queue_.top(); + eps_queue_.pop(); + while (visited_.size() <= state) visited_.push_back(false); + if (visited_[state]) continue; + visited_[state] = true; + visited_states_.push_front(state); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.weight = Times((*distance_)[state], arc.weight); + if (eps_filter_(arc)) { + while (visited_.size() <= arc.nextstate) visited_.push_back(false); + if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate); + } else { + const Element element(arc.ilabel, arc.olabel, arc.nextstate); + auto insert_result = element_map_.insert( + std::make_pair(element, std::make_pair(expand_id_, arcs_.size()))); + if (insert_result.second) { + arcs_.push_back(std::move(arc)); + } else { + if (insert_result.first->second.first == expand_id_) { + auto &weight = arcs_[insert_result.first->second.second].weight; + weight = Plus(weight, arc.weight); + } else { + insert_result.first->second.first = expand_id_; + insert_result.first->second.second = arcs_.size(); + arcs_.push_back(std::move(arc)); + } + } + } + } + final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); + } + while (!visited_states_.empty()) { + visited_[visited_states_.front()] = false; + visited_states_.pop_front(); + } + ++expand_id_; +} + +} // namespace internal + +// Removes epsilon-transitions (when both the input and output label are an +// epsilon) from a transducer. The result will be an equivalent FST that has no +// such epsilon transitions. This version modifies its input. It allows fine +// control via the options argument; see below for a simpler interface. +// +// The distance vector will be used to hold the shortest distances during the +// epsilon-closure computation. The state queue discipline and convergence delta +// are taken in the options argument. +template +void RmEpsilon(MutableFst *fst, + std::vector *distance, + const RmEpsilonOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (fst->Start() == kNoStateId) return; + // noneps_in[s] will be set to true iff s admits a non-epsilon incoming + // transition or is the start state. + std::vector noneps_in(fst->NumStates(), false); + noneps_in[fst->Start()] = true; + for (size_t i = 0; i < fst->NumStates(); ++i) { + for (ArcIterator> aiter(*fst, i); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (arc.ilabel != 0 || arc.olabel != 0) { + noneps_in[arc.nextstate] = true; + } + } + } + // States sorted in topological order when (acyclic) or generic topological + // order (cyclic). + std::vector states; + states.reserve(fst->NumStates()); + if (fst->Properties(kTopSorted, false) & kTopSorted) { + for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i); + } else if (fst->Properties(kAcyclic, false) & kAcyclic) { + std::vector order; + bool acyclic; + TopOrderVisitor top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter()); + // Sanity check: should be acyclic if property bit is set. + if (!acyclic) { + FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit"; + fst->SetProperties(kError, kError); + return; + } + states.resize(order.size()); + for (StateId i = 0; i < order.size(); i++) states[order[i]] = i; + } else { + uint64 props; + std::vector scc; + SccVisitor scc_visitor(&scc, nullptr, nullptr, &props); + DfsVisit(*fst, &scc_visitor, EpsilonArcFilter()); + std::vector first(scc.size(), kNoStateId); + std::vector next(scc.size(), kNoStateId); + for (StateId i = 0; i < scc.size(); i++) { + if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]]; + first[scc[i]] = i; + } + for (StateId i = 0; i < first.size(); i++) { + for (auto j = first[i]; j != kNoStateId; j = next[j]) { + states.push_back(j); + } + } + } + internal::RmEpsilonState rmeps_state(*fst, distance, opts); + while (!states.empty()) { + const auto state = states.back(); + states.pop_back(); + if (!noneps_in[state] && + (opts.connect || opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId)) { + continue; + } + rmeps_state.Expand(state); + fst->SetFinal(state, rmeps_state.Final()); + fst->DeleteArcs(state); + auto &arcs = rmeps_state.Arcs(); + fst->ReserveArcs(state, arcs.size()); + while (!arcs.empty()) { + fst->AddArc(state, arcs.back()); + arcs.pop_back(); + } + } + if (opts.connect || opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + for (size_t s = 0; s < fst->NumStates(); ++s) { + if (!noneps_in[s]) fst->DeleteArcs(s); + } + } + if (rmeps_state.Error()) fst->SetProperties(kError, kError); + fst->SetProperties( + RmEpsilonProperties(fst->Properties(kFstProperties, false)), + kFstProperties); + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + Prune(fst, opts.weight_threshold, opts.state_threshold); + } + if (opts.connect && opts.weight_threshold == Weight::Zero() && + opts.state_threshold == kNoStateId) { + Connect(fst); + } +} + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies its +// input. It has a simplified interface; see above for a version that +// allows finer control. +// +// Complexity: +// +// - Time: +// +// Unweighted: O(v^2 + ve). +// Acyclic: O(v^2 + V e). +// Tropical semiring: O(v^2 log V + ve). +// General: exponential. +// +// - Space: O(vE) +// +// where v is the number of states visited and e is the number of arcs visited. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. International Journal of Computer +// Science 13(1): 129-143. +template +void RmEpsilon(MutableFst *fst, bool connect = true, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId, + float delta = kShortestDelta) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + AutoQueue state_queue(*fst, &distance, EpsilonArcFilter()); + RmEpsilonOptions> opts( + &state_queue, delta, connect, weight_threshold, state_threshold); + RmEpsilon(fst, &distance, opts); +} + +struct RmEpsilonFstOptions : CacheOptions { + float delta; + + explicit RmEpsilonFstOptions(const CacheOptions &opts, + float delta = kShortestDelta) + : CacheOptions(opts), delta(delta) {} + + explicit RmEpsilonFstOptions(float delta = kShortestDelta) : delta(delta) {} +}; + +namespace internal { + +// Implementation of delayed RmEpsilonFst. +template +class RmEpsilonFstImpl : public CacheImpl { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::Properties; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + RmEpsilonFstImpl(const Fst &fst, const RmEpsilonFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + delta_(opts.delta), + rmeps_state_( + *fst_, &distance_, + RmEpsilonOptions>(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties( + RmEpsilonProperties(fst.Properties(kFstProperties, false), true), + kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + rmeps_state_( + *fst_, &distance_, + RmEpsilonOptions>(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(fst_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) Expand(s); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && + (fst_->Properties(kError, false) || rmeps_state_.Error())) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + rmeps_state_.Expand(s); + SetFinal(s, rmeps_state_.Final()); + auto &arcs = rmeps_state_.Arcs(); + while (!arcs.empty()) { + PushArc(s, std::move(arcs.back())); + arcs.pop_back(); + } + SetArcs(s); + } + + private: + std::unique_ptr> fst_; + float delta_; + std::vector distance_; + FifoQueue queue_; + internal::RmEpsilonState> rmeps_state_; +}; + +} // namespace internal + +// Removes epsilon-transitions (when both the input and output label are an +// epsilon) from a transducer. The result will be an equivalent FST that has no +// such epsilon transitions. This version is a +// delayed FST. +// +// Complexity: +// +// - Time: +// Unweighted: O(v^2 + ve). +// General: exponential. +// +// - Space: O(vE) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time to visit an input state or arc is assumed and exclusive of +// caching. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. International Journal of Computer +// Science 13(1): 129-143. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class RmEpsilonFst : public ImplToFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::RmEpsilonFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + explicit RmEpsilonFst(const Fst &fst) + : ImplToFst(std::make_shared(fst, RmEpsilonFstOptions())) {} + + RmEpsilonFst(const Fst &fst, const RmEpsilonFstOptions &opts) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RmEpsilonFst(const RmEpsilonFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. + RmEpsilonFst *Copy(bool safe = false) const override { + return new RmEpsilonFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RmEpsilonFst &operator=(const RmEpsilonFst &) = delete; +}; + +// Specialization for RmEpsilonFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const RmEpsilonFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for RmEpsilonFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RmEpsilonFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RmEpsilonFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful alias when using StdArc. +using StdRmEpsilonFst = RmEpsilonFst; + +} // namespace fst + +#endif // FST_RMEPSILON_H_ diff --git a/projects/llm_framework/include/fst/rmfinalepsilon.h b/projects/llm_framework/include/fst/rmfinalepsilon.h new file mode 100644 index 00000000..87e3a714 --- /dev/null +++ b/projects/llm_framework/include/fst/rmfinalepsilon.h @@ -0,0 +1,80 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to remove of final states that have epsilon-only input arcs. + +#ifndef FST_RMFINALEPSILON_H_ +#define FST_RMFINALEPSILON_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Removes final states that have epsilon-only input arcs. +template +void RmFinalEpsilon(MutableFst *fst) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + // Determines the coaccesibility of states. + std::vector access; + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + // Finds potential list of removable final states. These are final states that + // have no outgoing transitions or final states that have a non-coaccessible + // future. + std::unordered_set finals; + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst->Final(s) != Weight::Zero()) { + bool future_coaccess = false; + for (ArcIterator> aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (coaccess[arc.nextstate]) { + future_coaccess = true; + break; + } + } + if (!future_coaccess) finals.insert(s); + } + } + // Moves the final weight. + std::vector arcs; + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + auto weight = fst->Final(s); + arcs.clear(); + for (ArcIterator> aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + // Next state is in the list of finals. + if (finals.find(arc.nextstate) != finals.end()) { + // Sums up all epsilon arcs. + if (arc.ilabel == 0 && arc.olabel == 0) { + weight = Plus(Times(fst->Final(arc.nextstate), arc.weight), weight); + } else { + arcs.push_back(arc); + } + } else { + arcs.push_back(arc); + } + } + // If some arcs (epsilon arcs) were deleted, delete all arcs and add back + // only the non-epsilon arcs. + if (arcs.size() < fst->NumArcs(s)) { + fst->DeleteArcs(s); + fst->SetFinal(s, weight); + for (const auto &arc : arcs) fst->AddArc(s, arc); + } + } + Connect(fst); +} + +} // namespace fst + +#endif // FST_RMFINALEPSILON_H_ diff --git a/projects/llm_framework/include/fst/script/arc-class.h b/projects/llm_framework/include/fst/script/arc-class.h new file mode 100644 index 00000000..551266d7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arc-class.h @@ -0,0 +1,40 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ARC_CLASS_H_ +#define FST_SCRIPT_ARC_CLASS_H_ + +#include + +namespace fst { +namespace script { + +// A struct representing an arc while ignoring arc type. It is passed as an +// argument to AddArc. + +struct ArcClass { + template + explicit ArcClass(const Arc &arc) + : ilabel(arc.ilabel), olabel(arc.olabel), weight(arc.weight), + nextstate(arc.nextstate) {} + + ArcClass(int64 ilabel, int64 olabel, const WeightClass &weight, + int64 nextstate) + : ilabel(ilabel), olabel(olabel), weight(weight), nextstate(nextstate) {} + + template + Arc GetArc() const { + return Arc(ilabel, olabel, *(weight.GetWeight()), + nextstate); + } + + int64 ilabel; + int64 olabel; + WeightClass weight; + int64 nextstate; +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARC_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/arciterator-class.h b/projects/llm_framework/include/fst/script/arciterator-class.h new file mode 100644 index 00000000..8e4ca4f8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arciterator-class.h @@ -0,0 +1,212 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ARCITERATOR_CLASS_H_ +#define FST_SCRIPT_ARCITERATOR_CLASS_H_ + +#include +#include + +#include +#include + +// Scripting API support for ArcIterator. +// +// A call to Value() causes the underlying arc to be used to construct the +// associated ArcClass. + +namespace fst { +namespace script { + +// Non-mutable arc iterators. + +// Virtual interface implemented by each concrete ArcIteratorImpl. +class ArcIteratorImplBase { + public: + virtual bool Done() const = 0; + virtual uint32 Flags() const = 0; + virtual void Next() = 0; + virtual size_t Position() const = 0; + virtual void Reset() = 0; + virtual void Seek(size_t a) = 0; + virtual void SetFlags(uint32 flags, uint32 mask) = 0; + virtual ArcClass Value() const = 0; + virtual ~ArcIteratorImplBase() {} +}; + +// Templated implementation. +template +class ArcIteratorClassImpl : public ArcIteratorImplBase { + public: + explicit ArcIteratorClassImpl(const Fst &fst, int64 s) + : aiter_(fst, s) {} + + bool Done() const final { return aiter_.Done(); } + + uint32 Flags() const final { return aiter_.Flags(); } + + void Next() final { aiter_.Next(); } + + size_t Position() const final { return aiter_.Position(); } + + void Reset() final { aiter_.Reset(); } + + void Seek(size_t a) final { aiter_.Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) final { + aiter_.SetFlags(flags, mask); + } + + // This is returned by value because it has not yet been constructed, and + // is likely to participate in return-value optimization. + ArcClass Value() const final { return ArcClass(aiter_.Value()); } + + ~ArcIteratorClassImpl() final {} + + private: + ArcIterator> aiter_; +}; + +class ArcIteratorClass; + +using InitArcIteratorClassArgs = + std::tuple; + +// Untemplated user-facing class holding a templated pimpl. +class ArcIteratorClass { + public: + ArcIteratorClass(const FstClass &fst, int64 s); + + template + ArcIteratorClass(const Fst &fst, int64 s) + : impl_(new ArcIteratorClassImpl(fst, s)) {} + + bool Done() const { return impl_->Done(); } + + uint32 Flags() const { return impl_->Flags(); } + + void Next() { impl_->Next(); } + + size_t Position() const { return impl_->Position(); } + + void Reset() { impl_->Reset(); } + + void Seek(size_t a) { impl_->Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) { impl_->SetFlags(flags, mask); } + + ArcClass Value() const { return impl_->Value(); } + + template + friend void InitArcIteratorClass(InitArcIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitArcIteratorClass(InitArcIteratorClassArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::get<2>(*args)->impl_.reset( + new ArcIteratorClassImpl(fst, std::get<1>(*args))); +} + +// Mutable arc iterators. + +// Virtual interface implemented by each concrete MutableArcIteratorImpl. +class MutableArcIteratorImplBase : public ArcIteratorImplBase { + public: + virtual void SetValue(const ArcClass &) = 0; + + ~MutableArcIteratorImplBase() override {} +}; + +// Templated implementation. +template +class MutableArcIteratorClassImpl + : public MutableArcIteratorImplBase { + public: + explicit MutableArcIteratorClassImpl(MutableFst *fst, int64 s) + : aiter_(fst, s) {} + + bool Done() const final { return aiter_.Done(); } + + uint32 Flags() const final { return aiter_.Flags(); } + + void Next() final { aiter_.Next(); } + + size_t Position() const final { return aiter_.Position(); } + + void Reset() final { aiter_.Reset(); } + + void Seek(size_t a) final { aiter_.Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) final { + aiter_.SetFlags(flags, mask); + } + + void SetValue(const Arc &arc) { aiter_.SetValue(arc); } + + void SetValue(const ArcClass &ac) final { aiter_.SetValue(ac.GetArc()); } + + // This is returned by value because it has not yet been constructed, and + // is likely to participate in return-value optimization. + ArcClass Value() const final { return ArcClass(aiter_.Value()); } + + ~MutableArcIteratorClassImpl() override {} + + private: + MutableArcIterator> aiter_; +}; + +class MutableArcIteratorClass; + +using InitMutableArcIteratorClassArgs = + std::tuple; + +// Untemplated user-facing class holding a templated pimpl. +class MutableArcIteratorClass { + public: + MutableArcIteratorClass(MutableFstClass *fst, int64 s); + + template + MutableArcIteratorClass(MutableFst *fst, int64 s) + : impl_(new MutableArcIteratorClassImpl(fst, s)) {} + + bool Done() const { return impl_->Done(); } + + uint32 Flags() const { return impl_->Flags(); } + + void Next() { impl_->Next(); } + + size_t Position() const { return impl_->Position(); } + + void Reset() { impl_->Reset(); } + + void Seek(size_t a) { impl_->Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) { impl_->SetFlags(flags, mask); } + + void SetValue(const ArcClass &ac) { impl_->SetValue(ac); } + + ArcClass Value() const { return impl_->Value(); } + + template + friend void InitMutableArcIteratorClass( + InitMutableArcIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitMutableArcIteratorClass(InitMutableArcIteratorClassArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + std::get<2>(*args)->impl_.reset( + new MutableArcIteratorClassImpl(fst, std::get<1>(*args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARCITERATOR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/arcsort.h b/projects/llm_framework/include/fst/script/arcsort.h new file mode 100644 index 00000000..3e56fe5c --- /dev/null +++ b/projects/llm_framework/include/fst/script/arcsort.h @@ -0,0 +1,44 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ARCSORT_H_ +#define FST_SCRIPT_ARCSORT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +enum ArcSortType { + ILABEL_SORT, + OLABEL_SORT +}; + +using ArcSortArgs = std::pair; + +template +void ArcSort(ArcSortArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + switch (std::get<1>(*args)) { + case ILABEL_SORT: { + const ILabelCompare icomp; + ArcSort(fst, icomp); + return; + } + case OLABEL_SORT: { + const OLabelCompare ocomp; + ArcSort(fst, ocomp); + return; + } + } +} + +void ArcSort(MutableFstClass *ofst, ArcSortType); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARCSORT_H_ diff --git a/projects/llm_framework/include/fst/script/arg-packs.h b/projects/llm_framework/include/fst/script/arg-packs.h new file mode 100644 index 00000000..93cd4a05 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arg-packs.h @@ -0,0 +1,37 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// std::pair and std::tuple are used for the arguments of FstClass operations. +// +// If a function with a return value is required, use the WithReturnValue +// template as follows: +// +// WithReturnValue> + +#ifndef FST_SCRIPT_ARG_PACKS_H_ +#define FST_SCRIPT_ARG_PACKS_H_ + +#include + +namespace fst { +namespace script { + +// Tack this on to an existing type to add a return value. The syntax for +// accessing the args is then slightly more stilted, as you must do an extra +// member access (since the args are stored as a member of this class). + +template +struct WithReturnValue { + // Avoid reference-to-reference if ArgTuple is a reference. + using Args = typename std::remove_reference::type; + + Retval retval; + const Args &args; + + explicit WithReturnValue(const Args &args) : args(args) {} +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARG_PACKS_H_ diff --git a/projects/llm_framework/include/fst/script/closure.h b/projects/llm_framework/include/fst/script/closure.h new file mode 100644 index 00000000..7c68604a --- /dev/null +++ b/projects/llm_framework/include/fst/script/closure.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CLOSURE_H_ +#define FST_SCRIPT_CLOSURE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ClosureArgs = std::pair; + +template +void Closure(ClosureArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Closure(fst, std::get<1>(*args)); +} + +void Closure(MutableFstClass *ofst, ClosureType closure_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CLOSURE_H_ diff --git a/projects/llm_framework/include/fst/script/compile-impl.h b/projects/llm_framework/include/fst/script/compile-impl.h new file mode 100644 index 00000000..943a0b72 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compile-impl.h @@ -0,0 +1,217 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to to compile a binary FST from textual input. + +#ifndef FST_SCRIPT_COMPILE_IMPL_H_ +#define FST_SCRIPT_COMPILE_IMPL_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Compile a binary Fst from textual input, helper class for fstcompile.cc +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. +template +class FstCompiler { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // WARNING: use of negative labels not recommended as it may cause conflicts. + // If add_symbols_ is true, then the symbols will be dynamically added to the + // symbol tables. This is only useful if you set the (i/o)keep flag to attach + // the final symbol table, or use the accessors. (The input symbol tables are + // const and therefore not changed.) + FstCompiler(std::istream &istrm, const string &source, // NOLINT + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, bool allow_negative_labels = false) { + std::unique_ptr misyms(isyms ? isyms->Copy() : nullptr); + std::unique_ptr mosyms(osyms ? osyms->Copy() : nullptr); + std::unique_ptr mssyms(ssyms ? ssyms->Copy() : nullptr); + Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, + ikeep, okeep, nkeep, allow_negative_labels, false); + } + + FstCompiler(std::istream &istrm, const string &source, // NOLINT + SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels, bool add_symbols) { + Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep, + allow_negative_labels, add_symbols); + } + + void Init(std::istream &istrm, const string &source, // NOLINT + SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels, bool add_symbols) { + nline_ = 0; + source_ = source; + isyms_ = isyms; + osyms_ = osyms; + ssyms_ = ssyms; + nstates_ = 0; + keep_state_numbering_ = nkeep; + allow_negative_labels_ = allow_negative_labels; + add_symbols_ = add_symbols; + bool start_state_populated = false; + char line[kLineLen]; + const string separator = FLAGS_fst_field_separator + "\n"; + while (istrm.getline(line, kLineLen)) { + ++nline_; + std::vector col; + SplitString(line, separator.c_str(), &col, true); + if (col.empty() || col[0][0] == '\0') + continue; + if (col.size() > 5 || (col.size() > 4 && accep) || + (col.size() == 3 && !accep)) { + FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_ + << ", line = " << nline_; + fst_.SetProperties(kError, kError); + return; + } + StateId s = StrToStateId(col[0]); + while (s >= fst_.NumStates()) fst_.AddState(); + if (!start_state_populated) { + fst_.SetStart(s); + start_state_populated = true; + } + + Arc arc; + StateId d = s; + switch (col.size()) { + case 1: + fst_.SetFinal(s, Weight::One()); + break; + case 2: + fst_.SetFinal(s, StrToWeight(col[1], true)); + break; + case 3: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = arc.ilabel; + arc.weight = Weight::One(); + fst_.AddArc(s, arc); + break; + case 4: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + if (accep) { + arc.olabel = arc.ilabel; + arc.weight = StrToWeight(col[3], true); + } else { + arc.olabel = StrToOLabel(col[3]); + arc.weight = Weight::One(); + } + fst_.AddArc(s, arc); + break; + case 5: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = StrToOLabel(col[3]); + arc.weight = StrToWeight(col[4], true); + fst_.AddArc(s, arc); + } + while (d >= fst_.NumStates()) fst_.AddState(); + } + if (ikeep) fst_.SetInputSymbols(isyms); + if (okeep) fst_.SetOutputSymbols(osyms); + } + + const VectorFst &Fst() const { return fst_; } + + private: + // Maximum line length in text file. + static constexpr int kLineLen = 8096; + + StateId StrToId(const char *s, SymbolTable *syms, const char *name, + bool allow_negative = false) const { + StateId n = 0; + if (syms) { + n = (add_symbols_) ? syms->AddSymbol(s) : syms->Find(s); + if (n == -1 || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Symbol \"" << s + << "\" is not mapped to any integer " << name + << ", symbol table = " << syms->Name() + << ", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } else { + char *p; + n = strtoll(s, &p, 10); + if (p < s + strlen(s) || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } + return n; + } + + StateId StrToStateId(const char *s) { + StateId n = StrToId(s, ssyms_, "state ID"); + if (keep_state_numbering_) return n; + // Remaps state IDs to make dense set. + const auto it = states_.find(n); + if (it == states_.end()) { + states_[n] = nstates_; + return nstates_++; + } else { + return it->second; + } + } + + StateId StrToILabel(const char *s) const { + return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_); + } + + StateId StrToOLabel(const char *s) const { + return StrToId(s, osyms_, "arc olabel", allow_negative_labels_); + } + + Weight StrToWeight(const char *s, bool allow_zero) const { + Weight w; + std::istringstream strm(s); + strm >> w; + if (!strm || (!allow_zero && w == Weight::Zero())) { + FSTERROR() << "FstCompiler: Bad weight = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + w = Weight::NoWeight(); + } + return w; + } + + mutable VectorFst fst_; + size_t nline_; + string source_; // Text FST source name. + SymbolTable *isyms_; // ilabel symbol table (not owned). + SymbolTable *osyms_; // olabel symbol table (not owned). + SymbolTable *ssyms_; // slabel symbol table (not owned). + std::unordered_map states_; // State ID map. + StateId nstates_; // Number of seen states. + bool keep_state_numbering_; + bool allow_negative_labels_; // Not recommended; may cause conflicts. + bool add_symbols_; // Add to symbol tables on-the fly. + + FstCompiler(const FstCompiler &) = delete; + FstCompiler &operator=(const FstCompiler &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/compile.h b/projects/llm_framework/include/fst/script/compile.h new file mode 100644 index 00000000..c82ed477 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compile.h @@ -0,0 +1,98 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_COMPILE_H_ +#define FST_SCRIPT_COMPILE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +// This operation exists in two forms. 1 is a void operation which writes the +// compiled machine to disk; 2 returns an FstClass. I/O should normally be done +// using the binary format for efficiency, so users are STRONGLY ENCOURAGED to +// use 1 or to construct FSTs using the C++ FST mutation operations. + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct CompileFstInnerArgs { + std::istream &istrm; + const string &source; + const string &fst_type; + const fst::SymbolTable *isyms; + const fst::SymbolTable *osyms; + const fst::SymbolTable *ssyms; + const bool accep; + const bool ikeep; + const bool okeep; + const bool nkeep; + const bool allow_negative_labels; + + CompileFstInnerArgs(std::istream &istrm, const string &source, + const string &fst_type, const fst::SymbolTable *isyms, + const fst::SymbolTable *osyms, + const fst::SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, + bool allow_negative_labels = false) + : istrm(istrm), + source(source), + fst_type(fst_type), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accep(accep), + ikeep(ikeep), + okeep(okeep), + nkeep(nkeep), + allow_negative_labels(allow_negative_labels) {} +}; + +using CompileFstArgs = WithReturnValue; + +template +void CompileFstInternal(CompileFstArgs *args) { + using fst::Convert; + using fst::Fst; + using fst::FstCompiler; + FstCompiler fstcompiler( + args->args.istrm, args->args.source, args->args.isyms, args->args.osyms, + args->args.ssyms, args->args.accep, args->args.ikeep, args->args.okeep, + args->args.nkeep, args->args.allow_negative_labels); + const Fst *fst = &fstcompiler.Fst(); + std::unique_ptr> owned_fst; + if (args->args.fst_type != "vector") { + owned_fst.reset(Convert(*fst, args->args.fst_type)); + if (!owned_fst) { + FSTERROR() << "Failed to convert FST to desired type: " + << args->args.fst_type; + } + fst = owned_fst.get(); + } + args->retval = fst ? new FstClass(*fst) : nullptr; +} + +void CompileFst(std::istream &istrm, const string &source, const string &dest, + const string &fst_type, const string &arc_type, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep, + bool nkeep, bool allow_negative_labels); + +FstClass *CompileFstInternal(std::istream &istrm, const string &source, + const string &fst_type, const string &arc_type, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, + bool allow_negative_labels); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_H_ diff --git a/projects/llm_framework/include/fst/script/compose.h b/projects/llm_framework/include/fst/script/compose.h new file mode 100644 index 00000000..a1735803 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compose.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_COMPOSE_H_ +#define FST_SCRIPT_COMPOSE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ComposeArgs = std::tuple; + +template +void Compose(ComposeArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Compose(ifst1, ifst2, ofst, opts); +} + +void Compose(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/script/concat.h b/projects/llm_framework/include/fst/script/concat.h new file mode 100644 index 00000000..4bf8dc61 --- /dev/null +++ b/projects/llm_framework/include/fst/script/concat.h @@ -0,0 +1,40 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CONCAT_H_ +#define FST_SCRIPT_CONCAT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ConcatArgs1 = std::pair; + +template +void Concat(ConcatArgs1 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + const Fst &ifst = *(std::get<1>(*args).GetFst()); + Concat(ofst, ifst); +} + +using ConcatArgs2 = std::pair; + +template +void Concat(ConcatArgs2 *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Concat(ifst, ofst); +} + +void Concat(MutableFstClass *ofst, const FstClass &ifst); + +void Concat(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONCAT_H_ diff --git a/projects/llm_framework/include/fst/script/connect.h b/projects/llm_framework/include/fst/script/connect.h new file mode 100644 index 00000000..030102ac --- /dev/null +++ b/projects/llm_framework/include/fst/script/connect.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CONNECT_H_ +#define FST_SCRIPT_CONNECT_H_ + +#include +#include + +namespace fst { +namespace script { + +template +void Connect(MutableFstClass *fst) { + Connect(fst->GetMutableFst()); +} + +void Connect(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONNECT_H_ diff --git a/projects/llm_framework/include/fst/script/convert.h b/projects/llm_framework/include/fst/script/convert.h new file mode 100644 index 00000000..1a6eeaa3 --- /dev/null +++ b/projects/llm_framework/include/fst/script/convert.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CONVERT_H_ +#define FST_SCRIPT_CONVERT_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using ConvertInnerArgs = std::pair; + +using ConvertArgs = WithReturnValue; + +template +void Convert(ConvertArgs *args) { + const Fst &fst = *(std::get<0>(args->args).GetFst()); + const string &new_type = std::get<1>(args->args); + std::unique_ptr> result(Convert(fst, new_type)); + args->retval = result ? new FstClass(*result) : nullptr; +} + +FstClass *Convert(const FstClass &fst, const string &new_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONVERT_H_ diff --git a/projects/llm_framework/include/fst/script/decode.h b/projects/llm_framework/include/fst/script/decode.h new file mode 100644 index 00000000..09f25391 --- /dev/null +++ b/projects/llm_framework/include/fst/script/decode.h @@ -0,0 +1,49 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DECODE_H_ +#define FST_SCRIPT_DECODE_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using DecodeArgs1 = std::pair; + +template +void Decode(DecodeArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + std::unique_ptr> decoder( + EncodeMapper::Read(std::get<1>(*args), DECODE)); + if (!decoder) { + fst->SetProperties(kError, kError); + return; + } + Decode(fst, *decoder); +} + +using DecodeArgs2 = std::pair; + +template +void Decode(DecodeArgs2 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const EncodeMapper &encoder = + *(std::get<1>(*args).GetEncodeMapper()); + Decode(fst, encoder); +} + +void Decode(MutableFstClass *fst, const string &coder_fname); + +void Decode(MutableFstClass *fst, const EncodeMapperClass &encoder); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DECODE_H_ diff --git a/projects/llm_framework/include/fst/script/determinize.h b/projects/llm_framework/include/fst/script/determinize.h new file mode 100644 index 00000000..383a8fe4 --- /dev/null +++ b/projects/llm_framework/include/fst/script/determinize.h @@ -0,0 +1,59 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DETERMINIZE_H_ +#define FST_SCRIPT_DETERMINIZE_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +struct DeterminizeOptions { + const float delta; + const WeightClass &weight_threshold; + const int64 state_threshold; + const int64 subsequential_label; + const DeterminizeType det_type; + const bool increment_subsequential_label; + + DeterminizeOptions(float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + int64 subsequential_label = 0, + DeterminizeType det_type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false) + : delta(delta), + weight_threshold(weight_threshold), + state_threshold(state_threshold), + subsequential_label(subsequential_label), + det_type(det_type), + increment_subsequential_label(increment_subsequential_label) {} +}; + +using DeterminizeArgs = std::tuple; + +template +void Determinize(DeterminizeArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + const auto weight_threshold = *(opts.weight_threshold.GetWeight()); + const fst::DeterminizeOptions detargs(opts.delta, weight_threshold, + opts.state_threshold, opts.subsequential_label, opts.det_type, + opts.increment_subsequential_label); + Determinize(ifst, ofst, detargs); +} + +void Determinize(const FstClass &ifst, MutableFstClass *ofst, + const DeterminizeOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DETERMINIZE_H_ diff --git a/projects/llm_framework/include/fst/script/difference.h b/projects/llm_framework/include/fst/script/difference.h new file mode 100644 index 00000000..7af6200a --- /dev/null +++ b/projects/llm_framework/include/fst/script/difference.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DIFFERENCE_H_ +#define FST_SCRIPT_DIFFERENCE_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using DifferenceArgs = std::tuple; + +template +void Difference(DifferenceArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Difference(ifst1, ifst2, ofst, opts); +} + +void Difference(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DIFFERENCE_H_ diff --git a/projects/llm_framework/include/fst/script/disambiguate.h b/projects/llm_framework/include/fst/script/disambiguate.h new file mode 100644 index 00000000..acc1fba2 --- /dev/null +++ b/projects/llm_framework/include/fst/script/disambiguate.h @@ -0,0 +1,54 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DISAMBIGUATE_H_ +#define FST_SCRIPT_DISAMBIGUATE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +struct DisambiguateOptions { + const float delta; + const WeightClass &weight_threshold; + const int64 state_threshold; + const int64 subsequential_label; + + DisambiguateOptions(float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + int64 subsequential_label = 0) + : delta(delta), + weight_threshold(weight_threshold), + state_threshold(state_threshold), + subsequential_label(subsequential_label) {} +}; + +using DisambiguateArgs = std::tuple; + +template +void Disambiguate(DisambiguateArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + const auto weight_threshold = *(opts.weight_threshold.GetWeight()); + const fst::DisambiguateOptions disargs(opts.delta, weight_threshold, + opts.state_threshold, + opts.subsequential_label); + Disambiguate(ifst, ofst, disargs); +} + +void Disambiguate(const FstClass &ifst, MutableFstClass *ofst, + const DisambiguateOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DISAMBIGUATE_H_ diff --git a/projects/llm_framework/include/fst/script/draw-impl.h b/projects/llm_framework/include/fst/script/draw-impl.h new file mode 100644 index 00000000..f204b2e6 --- /dev/null +++ b/projects/llm_framework/include/fst/script/draw-impl.h @@ -0,0 +1,227 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to draw a binary FST by producing a text file in dot format, a helper +// class to fstdraw.cc. + +#ifndef FST_SCRIPT_DRAW_IMPL_H_ +#define FST_SCRIPT_DRAW_IMPL_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Print a binary FST in GraphViz textual format (helper class for fstdraw.cc). +// WARNING: Stand-alone use not recommend. +template +class FstDrawer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + FstDrawer(const Fst &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, bool show_weight_one) + : fst_(fst), + isyms_(isyms), + osyms_(osyms), + ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), + ostrm_(nullptr), + title_(title), + width_(width), + height_(height), + portrait_(portrait), + vertical_(vertical), + ranksep_(ranksep), + nodesep_(nodesep), + fontsize_(fontsize), + precision_(precision), + float_format_(float_format), + show_weight_one_(show_weight_one) {} + + // Draws FST to an output buffer. + void Draw(std::ostream *strm, const string &dest) { + ostrm_ = strm; + SetStreamState(ostrm_); + dest_ = dest; + StateId start = fst_.Start(); + if (start == kNoStateId) return; + PrintString("digraph FST {\n"); + if (vertical_) { + PrintString("rankdir = BT;\n"); + } else { + PrintString("rankdir = LR;\n"); + } + PrintString("size = \""); + Print(width_); + PrintString(","); + Print(height_); + PrintString("\";\n"); + if (!dest_.empty()) PrintString("label = \"" + title_ + "\";\n"); + PrintString("center = 1;\n"); + if (portrait_) { + PrintString("orientation = Portrait;\n"); + } else { + PrintString("orientation = Landscape;\n"); + } + PrintString("ranksep = \""); + Print(ranksep_); + PrintString("\";\n"); + PrintString("nodesep = \""); + Print(nodesep_); + PrintString("\";\n"); + // Initial state first. + DrawState(start); + for (StateIterator> siter(fst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s != start) DrawState(s); + } + PrintString("}\n"); + } + + private: + void SetStreamState(std::ostream* strm) const { + strm->precision(precision_); + if (float_format_ == "e") + strm->setf(std::ios_base::scientific, std::ios_base::floatfield); + if (float_format_ == "f") + strm->setf(std::ios_base::fixed, std::ios_base::floatfield); + // O.w. defaults to "g" per standard lib. + } + + void PrintString(const string &str) const { *ostrm_ << str; } + + // Escapes backslash and double quote if these occur in the string. Dot will + // not deal gracefully with these if they are not escaped. + static string Escape(const string &str) { + string ns; + for (char c : str) { + if (c == '\\' || c == '"') ns.push_back('\\'); + ns.push_back(c); + } + return ns; + } + + void PrintId(StateId id, const SymbolTable *syms, const char *name) const { + if (syms) { + auto symbol = syms->Find(id); + if (symbol.empty()) { + FSTERROR() << "FstDrawer: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } + PrintString(Escape(symbol)); + } else { + PrintString(std::to_string(id)); + } + } + + void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); } + + void PrintILabel(Label label) const { + PrintId(label, isyms_, "arc input label"); + } + + void PrintOLabel(Label label) const { + PrintId(label, osyms_, "arc output label"); + } + + void PrintWeight(Weight w) const { + // Weight may have double quote characters in it, so escape it. + PrintString(Escape(ToString(w))); + } + + template + void Print(T t) const { *ostrm_ << t; } + + template + string ToString(T t) const { + std::stringstream ss; + SetStreamState(&ss); + ss << t; + return ss.str(); + } + + void DrawState(StateId s) const { + Print(s); + PrintString(" [label = \""); + PrintStateId(s); + const auto weight = fst_.Final(s); + if (weight != Weight::Zero()) { + if (show_weight_one_ || (weight != Weight::One())) { + PrintString("/"); + PrintWeight(weight); + } + PrintString("\", shape = doublecircle,"); + } else { + PrintString("\", shape = circle,"); + } + if (s == fst_.Start()) { + PrintString(" style = bold,"); + } else { + PrintString(" style = solid,"); + } + PrintString(" fontsize = "); + Print(fontsize_); + PrintString("]\n"); + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + PrintString("\t"); + Print(s); + PrintString(" -> "); + Print(arc.nextstate); + PrintString(" [label = \""); + PrintILabel(arc.ilabel); + if (!accep_) { + PrintString(":"); + PrintOLabel(arc.olabel); + } + if (show_weight_one_ || (arc.weight != Weight::One())) { + PrintString("/"); + PrintWeight(arc.weight); + } + PrintString("\", fontsize = "); + Print(fontsize_); + PrintString("];\n"); + } + } + + const Fst &fst_; + const SymbolTable *isyms_; // ilabel symbol table. + const SymbolTable *osyms_; // olabel symbol table. + const SymbolTable *ssyms_; // slabel symbol table. + bool accep_; // Print as acceptor when possible. + std::ostream *ostrm_; // Drawn FST destination. + string dest_; // Drawn FST destination name. + + string title_; + float width_; + float height_; + bool portrait_; + bool vertical_; + float ranksep_; + float nodesep_; + int fontsize_; + int precision_; + string float_format_; + bool show_weight_one_; + + FstDrawer(const FstDrawer &) = delete; + FstDrawer &operator=(const FstDrawer &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_DRAW_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/draw.h b/projects/llm_framework/include/fst/script/draw.h new file mode 100644 index 00000000..cb37df1e --- /dev/null +++ b/projects/llm_framework/include/fst/script/draw.h @@ -0,0 +1,85 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DRAW_H_ +#define FST_SCRIPT_DRAW_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FstDrawerArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accep; + const string &title; + const float width; + const float height; + const bool portrait; + const bool vertical; + const float ranksep; + const float nodesep; + const int fontsize; + const int precision; + const string &float_format; // NOLINT + const bool show_weight_one; + std::ostream *ostrm; + const string &dest; + + FstDrawerArgs(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, + bool show_weight_one, std::ostream *ostrm, const string &dest) + : fst(fst), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accep(accep), + title(title), + width(width), + height(height), + portrait(portrait), + vertical(vertical), + ranksep(ranksep), + nodesep(nodesep), + fontsize(fontsize), + precision(precision), + float_format(float_format), + show_weight_one(show_weight_one), + ostrm(ostrm), + dest(dest) {} +}; + +template +void DrawFst(FstDrawerArgs *args) { + const Fst &fst = *(args->fst.GetFst()); + FstDrawer fstdrawer(fst, args->isyms, args->osyms, args->ssyms, + args->accep, args->title, args->width, args->height, args->portrait, + args->vertical, args->ranksep, args->nodesep, args->fontsize, + args->precision, args->float_format, args->show_weight_one); + fstdrawer.Draw(args->ostrm, args->dest); +} + +void DrawFst(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, bool show_weight_one, + std::ostream *ostrm, const string &dest); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DRAW_H_ diff --git a/projects/llm_framework/include/fst/script/encode.h b/projects/llm_framework/include/fst/script/encode.h new file mode 100644 index 00000000..6a869680 --- /dev/null +++ b/projects/llm_framework/include/fst/script/encode.h @@ -0,0 +1,51 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ENCODE_H_ +#define FST_SCRIPT_ENCODE_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EncodeArgs1 = std::tuple; + +template +void Encode(EncodeArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const string &coder_fname = std::get<3>(*args); + // If true, reuse encode from disk. If false, make a new encoder and just use + // the filename argument as the destination state. + std::unique_ptr> encoder( + std::get<2>(*args) ? EncodeMapper::Read(coder_fname, ENCODE) + : new EncodeMapper(std::get<1>(*args), ENCODE)); + Encode(fst, encoder.get()); + if (!std::get<2>(*args)) encoder->Write(coder_fname); +} + +using EncodeArgs2 = std::pair; + +template +void Encode(EncodeArgs2 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + EncodeMapper *encoder = std::get<1>(*args)->GetEncodeMapper(); + Encode(fst, encoder); +} + +void Encode(MutableFstClass *fst, uint32 flags, bool reuse_encoder, + const string &coder_fname); + +void Encode(MutableFstClass *fst, EncodeMapperClass *encoder); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ENCODE_H_ diff --git a/projects/llm_framework/include/fst/script/encodemapper-class.h b/projects/llm_framework/include/fst/script/encodemapper-class.h new file mode 100644 index 00000000..b62824f1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/encodemapper-class.h @@ -0,0 +1,169 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_ +#define FST_SCRIPT_ENCODEMAPPER_CLASS_H_ + +#include +#include +#include + +#include +#include +#include + +// Scripting API support for EncodeMapper. + +namespace fst { +namespace script { + +// Virtual interface implemented by each concrete EncodeMapperClassImpl. +class EncodeMapperImplBase { + public: + // Returns an encoded ArcClass. + virtual ArcClass operator()(const ArcClass &a) = 0; + virtual const string &ArcType() const = 0; + virtual uint32 Flags() const = 0; + virtual uint64 Properties(uint64 inprops) = 0; + virtual EncodeType Type() const = 0; + virtual const SymbolTable *InputSymbols() const = 0; + virtual const SymbolTable *OutputSymbols() const = 0; + virtual void SetInputSymbols(const SymbolTable *syms) = 0; + virtual void SetOutputSymbols(const SymbolTable *syms) = 0; + virtual const string &WeightType() const = 0; + virtual ~EncodeMapperImplBase() {} +}; + +// Templated implementation. +template +class EncodeMapperClassImpl : public EncodeMapperImplBase { + public: + EncodeMapperClassImpl(uint32 flags, EncodeType type) + : encoder_(flags, type) {} + + ArcClass operator()(const ArcClass &a) final; + + const string &ArcType() const final { return Arc::Type(); } + + uint32 Flags() const final { return encoder_.Flags(); } + + uint64 Properties(uint64 inprops) final { + return encoder_.Properties(inprops); + } + + EncodeType Type() const final { return encoder_.Type(); } + + const SymbolTable *InputSymbols() const final { + return encoder_.InputSymbols(); + } + + const SymbolTable *OutputSymbols() const final { + return encoder_.OutputSymbols(); + } + + void SetInputSymbols(const SymbolTable *syms) final { + encoder_.SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable *syms) final { + encoder_.SetOutputSymbols(syms); + } + + const string &WeightType() const final { return Arc::Weight::Type(); } + + ~EncodeMapperClassImpl() override {} + + EncodeMapper *GetImpl() const { return &encoder_; } + + EncodeMapper *GetImpl() { return &encoder_; } + + private: + EncodeMapper encoder_; +}; + +// This is returned by value because it is very likely to undergo return-value +// optimization. +template +inline ArcClass EncodeMapperClassImpl::operator()(const ArcClass &a) { + Arc arc(a.ilabel, a.olabel, *(a.weight.GetWeight()), + a.nextstate); + return ArcClass(encoder_(arc)); +} + +class EncodeMapperClass; + +using InitEncodeMapperClassArgs = + std::tuple; + +class EncodeMapperClass { + public: + EncodeMapperClass(const string &arc_type, uint32 flags, EncodeType type); + + template + EncodeMapperClass(uint32 flags, EncodeType type) + : impl_(new EncodeMapperClassImpl(flags, type)) {} + + ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); } + + const string &ArcType() const { return impl_->ArcType(); } + + uint32 Flags() const { return impl_->Flags(); } + + uint64 Properties(uint64 inprops) { return impl_->Properties(inprops); } + + EncodeType Type() const { return impl_->Type(); } + + const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); } + + const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); } + + void SetInputSymbols(const SymbolTable *syms) { + impl_->SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable *syms) { + impl_->SetOutputSymbols(syms); + } + + const string &WeightType() const { return impl_->WeightType(); } + + template + friend void InitEncodeMapperClass(InitEncodeMapperClassArgs *args); + + // Naturally, this exists in non-const and const forms. Encoding arcs or FSTs + // mutates the underlying encoder; decoding them does not. + + template + EncodeMapper *GetEncodeMapper() { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + template + const EncodeMapper *GetEncodeMapper() const { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + private: + std::unique_ptr impl_; +}; + +template +void InitEncodeMapperClass(InitEncodeMapperClassArgs *args) { + std::get<2>(*args)->impl_.reset( + new EncodeMapperClassImpl(std::get<0>(*args), std::get<1>(*args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/epsnormalize.h b/projects/llm_framework/include/fst/script/epsnormalize.h new file mode 100644 index 00000000..b55fefae --- /dev/null +++ b/projects/llm_framework/include/fst/script/epsnormalize.h @@ -0,0 +1,31 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_EPSNORMALIZE_H_ +#define FST_SCRIPT_EPSNORMALIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using EpsNormalizeArgs = std::tuple; + +template +void EpsNormalize(EpsNormalizeArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + EpsNormalize(ifst, ofst, std::get<2>(*args)); +} + +void EpsNormalize(const FstClass &ifst, MutableFstClass *ofst, + EpsNormalizeType norm_type = EPS_NORM_INPUT); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EPSNORMALIZE_H_ diff --git a/projects/llm_framework/include/fst/script/equal.h b/projects/llm_framework/include/fst/script/equal.h new file mode 100644 index 00000000..79ea9aa4 --- /dev/null +++ b/projects/llm_framework/include/fst/script/equal.h @@ -0,0 +1,32 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_EQUAL_H_ +#define FST_SCRIPT_EQUAL_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EqualInnerArgs = std::tuple; + +using EqualArgs = WithReturnValue; + +template +void Equal(EqualArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Equal(fst1, fst2, std::get<2>(args->args)); +} + +bool Equal(const FstClass &fst1, const FstClass &fst2, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EQUAL_H_ diff --git a/projects/llm_framework/include/fst/script/equivalent.h b/projects/llm_framework/include/fst/script/equivalent.h new file mode 100644 index 00000000..7cdff45e --- /dev/null +++ b/projects/llm_framework/include/fst/script/equivalent.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_EQUIVALENT_H_ +#define FST_SCRIPT_EQUIVALENT_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EquivalentInnerArgs = std::tuple; + +using EquivalentArgs = WithReturnValue; + +template +void Equivalent(EquivalentArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Equivalent(fst1, fst2, std::get<2>(args->args)); +} + +bool Equivalent(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/script/fst-class.h b/projects/llm_framework/include/fst/script/fst-class.h new file mode 100644 index 00000000..07319fc7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fst-class.h @@ -0,0 +1,530 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_FST_CLASS_H_ +#define FST_SCRIPT_FST_CLASS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// Classes to support "boxing" all existing types of FST arcs in a single +// FstClass which hides the arc types. This allows clients to load +// and work with FSTs without knowing the arc type. These classes are only +// recommended for use in high-level scripting applications. Most users should +// use the lower-level templated versions corresponding to these classes. + +namespace fst { +namespace script { + +// Abstract base class defining the set of functionalities implemented in all +// impls and passed through by all bases. Below FstClassBase the class +// hierarchy bifurcates; FstClassImplBase serves as the base class for all +// implementations (of which FstClassImpl is currently the only one) and +// FstClass serves as the base class for all interfaces. + +class FstClassBase { + public: + virtual const string &ArcType() const = 0; + virtual WeightClass Final(int64) const = 0; + virtual const string &FstType() const = 0; + virtual const SymbolTable *InputSymbols() const = 0; + virtual size_t NumArcs(int64) const = 0; + virtual size_t NumInputEpsilons(int64) const = 0; + virtual size_t NumOutputEpsilons(int64) const = 0; + virtual const SymbolTable *OutputSymbols() const = 0; + virtual uint64 Properties(uint64, bool) const = 0; + virtual int64 Start() const = 0; + virtual const string &WeightType() const = 0; + virtual bool ValidStateId(int64) const = 0; + virtual bool Write(const string &) const = 0; + virtual bool Write(std::ostream &, const string &) const = 0; + virtual ~FstClassBase() {} +}; + +// Adds all the MutableFst methods. +class FstClassImplBase : public FstClassBase { + public: + virtual bool AddArc(int64, const ArcClass &) = 0; + virtual int64 AddState() = 0; + virtual FstClassImplBase *Copy() = 0; + virtual bool DeleteArcs(int64, size_t) = 0; + virtual bool DeleteArcs(int64) = 0; + virtual bool DeleteStates(const std::vector &) = 0; + virtual void DeleteStates() = 0; + virtual SymbolTable *MutableInputSymbols() = 0; + virtual SymbolTable *MutableOutputSymbols() = 0; + virtual int64 NumStates() const = 0; + virtual bool ReserveArcs(int64, size_t) = 0; + virtual void ReserveStates(int64) = 0; + virtual void SetInputSymbols(SymbolTable *) = 0; + virtual bool SetFinal(int64, const WeightClass &) = 0; + virtual void SetOutputSymbols(SymbolTable *) = 0; + virtual void SetProperties(uint64, uint64) = 0; + virtual bool SetStart(int64) = 0; + ~FstClassImplBase() override {} +}; + +// Containiner class wrapping an Fst, hiding its arc type. Whether this +// Fst pointer refers to a special kind of FST (e.g. a MutableFst) is +// known by the type of interface class that owns the pointer to this +// container. + +template +class FstClassImpl : public FstClassImplBase { + public: + explicit FstClassImpl(Fst *impl, bool should_own = false) + : impl_(should_own ? impl : impl->Copy()) {} + + explicit FstClassImpl(const Fst &impl) : impl_(impl.Copy()) {} + + // Warning: calling this method casts the FST to a mutable FST. + bool AddArc(int64 s, const ArcClass &ac) final { + if (!ValidStateId(s)) return false; + // Note that we do not check that the destination state is valid, so users + // can add arcs before they add the corresponding states. Verify can be + // used to determine whether any arc has a nonexisting destination. + Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight(), + ac.nextstate); + static_cast *>(impl_.get())->AddArc(s, arc); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + int64 AddState() final { + return static_cast *>(impl_.get())->AddState(); + } + + const string &ArcType() const final { return Arc::Type(); } + + FstClassImpl *Copy() final { return new FstClassImpl(impl_.get()); } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteArcs(int64 s, size_t n) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->DeleteArcs(s, n); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteArcs(int64 s) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->DeleteArcs(s); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteStates(const std::vector &dstates) final { + for (const auto &state : dstates) + if (!ValidStateId(state)) return false; + // Warning: calling this method with any integers beyond the precision of + // the underlying FST will result in truncation. + std::vector typed_dstates(dstates.size()); + std::copy(dstates.begin(), dstates.end(), typed_dstates.begin()); + static_cast *>(impl_.get())->DeleteStates(typed_dstates); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void DeleteStates() final { + static_cast *>(impl_.get())->DeleteStates(); + } + + WeightClass Final(int64 s) const final { + if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType()); + WeightClass w(impl_->Final(s)); + return w; + } + + const string &FstType() const final { return impl_->Type(); } + + const SymbolTable *InputSymbols() const final { + return impl_->InputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + SymbolTable *MutableInputSymbols() final { + return static_cast *>(impl_.get())->MutableInputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + SymbolTable *MutableOutputSymbols() final { + return static_cast *>(impl_.get())->MutableOutputSymbols(); + } + + // Signals failure by returning size_t max. + size_t NumArcs(int64 s) const final { + return ValidStateId(s) ? impl_->NumArcs(s) + : std::numeric_limits::max(); + } + + // Signals failure by returning size_t max. + size_t NumInputEpsilons(int64 s) const final { + return ValidStateId(s) ? impl_->NumInputEpsilons(s) + : std::numeric_limits::max(); + } + + // Signals failure by returning size_t max. + size_t NumOutputEpsilons(int64 s) const final { + return ValidStateId(s) ? impl_->NumOutputEpsilons(s) + : std::numeric_limits::max(); + } + + // Warning: calling this method casts the FST to a mutable FST. + int64 NumStates() const final { + return static_cast *>(impl_.get())->NumStates(); + } + + uint64 Properties(uint64 mask, bool test) const final { + return impl_->Properties(mask, test); + } + + // Warning: calling this method casts the FST to a mutable FST. + bool ReserveArcs(int64 s, size_t n) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->ReserveArcs(s, n); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void ReserveStates(int64 s) final { + static_cast *>(impl_.get())->ReserveStates(s); + } + + const SymbolTable *OutputSymbols() const final { + return impl_->OutputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetInputSymbols(SymbolTable *isyms) final { + static_cast *>(impl_.get())->SetInputSymbols(isyms); + } + + // Warning: calling this method casts the FST to a mutable FST. + bool SetFinal(int64 s, const WeightClass &weight) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get()) + ->SetFinal(s, *weight.GetWeight()); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetOutputSymbols(SymbolTable *osyms) final { + static_cast *>(impl_.get())->SetOutputSymbols(osyms); + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetProperties(uint64 props, uint64 mask) final { + static_cast *>(impl_.get())->SetProperties(props, mask); + } + + // Warning: calling this method casts the FST to a mutable FST. + bool SetStart(int64 s) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->SetStart(s); + return true; + } + + int64 Start() const final { return impl_->Start(); } + + bool ValidStateId(int64 s) const final { + // This cowardly refuses to count states if the FST is not yet expanded. + if (!Properties(kExpanded, true)) { + FSTERROR() << "Cannot get number of states for unexpanded FST"; + return false; + } + // If the FST is already expanded, CountStates calls NumStates. + if (s < 0 || s >= CountStates(*impl_)) { + FSTERROR() << "State ID " << s << " not valid"; + return false; + } + return true; + } + + const string &WeightType() const final { return Arc::Weight::Type(); } + + bool Write(const string &fname) const final { return impl_->Write(fname); } + + bool Write(std::ostream &ostr, const string &fname) const final { + const FstWriteOptions opts(fname); + return impl_->Write(ostr, opts); + } + + ~FstClassImpl() override {} + + Fst *GetImpl() const { return impl_.get(); } + + private: + std::unique_ptr> impl_; +}; + +// BASE CLASS DEFINITIONS + +class MutableFstClass; + +class FstClass : public FstClassBase { + public: + FstClass() : impl_(nullptr) {} + + template + explicit FstClass(const Fst &fst) : impl_(new FstClassImpl(fst)) {} + + FstClass(const FstClass &other) + : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {} + + FstClass &operator=(const FstClass &other) { + impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy()); + return *this; + } + + WeightClass Final(int64 s) const final { return impl_->Final(s); } + + const string &ArcType() const final { return impl_->ArcType(); } + + const string &FstType() const final { return impl_->FstType(); } + + const SymbolTable *InputSymbols() const final { + return impl_->InputSymbols(); + } + + size_t NumArcs(int64 s) const final { return impl_->NumArcs(s); } + + size_t NumInputEpsilons(int64 s) const final { + return impl_->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(int64 s) const final { + return impl_->NumOutputEpsilons(s); + } + + const SymbolTable *OutputSymbols() const final { + return impl_->OutputSymbols(); + } + + uint64 Properties(uint64 mask, bool test) const final { + // Special handling for FSTs with a null impl. + if (!impl_) return kError & mask; + return impl_->Properties(mask, test); + } + + static FstClass *Read(const string &fname); + + static FstClass *Read(std::istream &istrm, const string &source); + + int64 Start() const final { return impl_->Start(); } + + bool ValidStateId(int64 s) const final { return impl_->ValidStateId(s); } + + const string &WeightType() const final { return impl_->WeightType(); } + + // Helper that logs an ERROR if the weight type of an FST and a WeightClass + // don't match. + + bool WeightTypesMatch(const WeightClass &weight, const string &op_name) const; + + bool Write(const string &fname) const final { return impl_->Write(fname); } + + bool Write(std::ostream &ostr, const string &fname) const final { + return impl_->Write(ostr, fname); + } + + ~FstClass() override {} + + // These methods are required by IO registration. + + template + static FstClassImplBase *Convert(const FstClass &other) { + FSTERROR() << "Doesn't make sense to convert any class to type FstClass"; + return nullptr; + } + + template + static FstClassImplBase *Create() { + FSTERROR() << "Doesn't make sense to create an FstClass with a " + << "particular arc type"; + return nullptr; + } + + template + const Fst *GetFst() const { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + FstClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + template + static FstClass *Read(std::istream &stream, const FstReadOptions &opts) { + if (!opts.header) { + LOG(ERROR) << "FstClass::Read: Options header not specified"; + return nullptr; + } + const FstHeader &hdr = *opts.header; + if (hdr.Properties() & kMutable) { + return ReadTypedFst>(stream, opts); + } else { + return ReadTypedFst>(stream, opts); + } + } + + protected: + explicit FstClass(FstClassImplBase *impl) : impl_(impl) {} + + const FstClassImplBase *GetImpl() const { return impl_.get(); } + + FstClassImplBase *GetImpl() { return impl_.get(); } + + // Generic template method for reading an arc-templated FST of type + // UnderlyingT, and returning it wrapped as FstClassT, with appropriat + // error checking. Called from arc-templated Read() static methods. + template + static FstClassT *ReadTypedFst(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr u(UnderlyingT::Read(stream, opts)); + return u ? new FstClassT(*u) : nullptr; + } + + private: + std::unique_ptr impl_; +}; + +// Specific types of FstClass with special properties + +class MutableFstClass : public FstClass { + public: + bool AddArc(int64 s, const ArcClass &ac) { + if (!WeightTypesMatch(ac.weight, "AddArc")) return false; + return GetImpl()->AddArc(s, ac); + } + + int64 AddState() { return GetImpl()->AddState(); } + + bool DeleteArcs(int64 s, size_t n) { return GetImpl()->DeleteArcs(s, n); } + + bool DeleteArcs(int64 s) { return GetImpl()->DeleteArcs(s); } + + bool DeleteStates(const std::vector &dstates) { + return GetImpl()->DeleteStates(dstates); + } + + void DeleteStates() { GetImpl()->DeleteStates(); } + + SymbolTable *MutableInputSymbols() { + return GetImpl()->MutableInputSymbols(); + } + + SymbolTable *MutableOutputSymbols() { + return GetImpl()->MutableOutputSymbols(); + } + + int64 NumStates() const { return GetImpl()->NumStates(); } + + bool ReserveArcs(int64 s, size_t n) { return GetImpl()->ReserveArcs(s, n); } + + void ReserveStates(int64 s) { GetImpl()->ReserveStates(s); } + + static MutableFstClass *Read(const string &fname, bool convert = false); + + void SetInputSymbols(SymbolTable *isyms) { + GetImpl()->SetInputSymbols(isyms); + } + + bool SetFinal(int64 s, const WeightClass &weight) { + if (!WeightTypesMatch(weight, "SetFinal")) return false; + return GetImpl()->SetFinal(s, weight); + } + + void SetOutputSymbols(SymbolTable *osyms) { + GetImpl()->SetOutputSymbols(osyms); + } + + void SetProperties(uint64 props, uint64 mask) { + GetImpl()->SetProperties(props, mask); + } + + bool SetStart(int64 s) { return GetImpl()->SetStart(s); } + + template + explicit MutableFstClass(const MutableFst &fst) : FstClass(fst) {} + + // These methods are required by IO registration. + + template + static FstClassImplBase *Convert(const FstClass &other) { + FSTERROR() << "Doesn't make sense to convert any class to type " + << "MutableFstClass"; + return nullptr; + } + + template + static FstClassImplBase *Create() { + FSTERROR() << "Doesn't make sense to create a MutableFstClass with a " + << "particular arc type"; + return nullptr; + } + + template + MutableFst *GetMutableFst() { + Fst *fst = const_cast *>(this->GetFst()); + MutableFst *mfst = static_cast *>(fst); + return mfst; + } + + template + static MutableFstClass *Read(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr> mfst(MutableFst::Read(stream, opts)); + return mfst ? new MutableFstClass(*mfst) : nullptr; + } + + protected: + explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) {} +}; + +class VectorFstClass : public MutableFstClass { + public: + explicit VectorFstClass(FstClassImplBase *impl) : MutableFstClass(impl) {} + + explicit VectorFstClass(const FstClass &other); + + explicit VectorFstClass(const string &arc_type); + + static VectorFstClass *Read(const string &fname); + + template + static VectorFstClass *Read(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr> mfst(VectorFst::Read(stream, opts)); + return mfst ? new VectorFstClass(*mfst) : nullptr; + } + + template + explicit VectorFstClass(const VectorFst &fst) : MutableFstClass(fst) {} + + template + static FstClassImplBase *Convert(const FstClass &other) { + return new FstClassImpl(new VectorFst(*other.GetFst()), + true); + } + + template + static FstClassImplBase *Create() { + return new FstClassImpl(new VectorFst(), true); + } +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_FST_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/fstscript-decl.h b/projects/llm_framework/include/fst/script/fstscript-decl.h new file mode 100644 index 00000000..294d0159 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fstscript-decl.h @@ -0,0 +1,32 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Forward declarations for the FST and FST script classes. + +#ifndef FST_SCRIPT_FSTSCRIPT_DECL_H_ +#define FST_SCRIPT_FSTSCRIPT_DECL_H_ + +#include + +namespace fst { +namespace script { + +class ArcClass; + +class ArcIteratorClass; +class MutableArcIteratorClass; + +class EncodeMapperClass; + +class FstClass; +class MutableFstClass; +class VectorFstClass; + +class StateIteratorClass; + +class WeightClass; + +} // namespace script +} // namespace fst; + +#endif // FST_SCRIPT_FSTSCRIPT_DECL_H_ diff --git a/projects/llm_framework/include/fst/script/fstscript.h b/projects/llm_framework/include/fst/script/fstscript.h new file mode 100644 index 00000000..45f16175 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fstscript.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// The FST script interface permits users to interact with FSTs without knowing +// their arc type. It does this by mapping compile-time polymorphism (in the +// form of a arc-templated FST types) onto a shared virtual interface. It also +// supports arc extension via a DSO interface. Due to the overhead of virtual +// dispatch and registered function lookups, the script API is somewhat slower +// then library API provided by types like StdVectorFst, but has the advantage +// that it is designed not to crash (and to provide useful debugging +// information) upon common user errors like passing invalid indices or +// attempting comparison of incompatible FSTs. It is used both by the FST +// binaries and the Python extension. +// +// This header includes all of the FST script functionality. + +#ifndef FST_SCRIPT_FSTSCRIPT_H_ +#define FST_SCRIPT_FSTSCRIPT_H_ + +// Major classes +#include +#include +#include +#include +#include +#include + +// Flag-to-enum parsers. +#include +// Templates like Operation<> and Apply<>. +#include + +// Operations. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// This class is necessary because registering each of the operations +// separately overfills the stack, as there's so many of them. +namespace fst { +namespace script { + +template +class AllFstOperationsRegisterer { + public: + AllFstOperationsRegisterer() { + RegisterBatch1(); + RegisterBatch2(); + } + + private: + void RegisterBatch1() { + REGISTER_FST_OPERATION(ArcSort, Arc, ArcSortArgs); + REGISTER_FST_OPERATION(Closure, Arc, ClosureArgs); + REGISTER_FST_OPERATION(CompileFstInternal, Arc, CompileFstArgs); + REGISTER_FST_OPERATION(Compose, Arc, ComposeArgs); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs1); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs2); + REGISTER_FST_OPERATION(Connect, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Convert, Arc, ConvertArgs); + REGISTER_FST_OPERATION(Decode, Arc, DecodeArgs1); + REGISTER_FST_OPERATION(Decode, Arc, DecodeArgs2); + REGISTER_FST_OPERATION(Determinize, Arc, DeterminizeArgs); + REGISTER_FST_OPERATION(Difference, Arc, DifferenceArgs); + REGISTER_FST_OPERATION(Disambiguate, Arc, DisambiguateArgs); + REGISTER_FST_OPERATION(DrawFst, Arc, FstDrawerArgs); + REGISTER_FST_OPERATION(Encode, Arc, EncodeArgs1); + REGISTER_FST_OPERATION(Encode, Arc, EncodeArgs2); + REGISTER_FST_OPERATION(EpsNormalize, Arc, EpsNormalizeArgs); + REGISTER_FST_OPERATION(Equal, Arc, EqualArgs); + REGISTER_FST_OPERATION(Equivalent, Arc, EquivalentArgs); + REGISTER_FST_OPERATION(PrintFstInfo, Arc, InfoArgs); + REGISTER_FST_OPERATION(GetFstInfo, Arc, GetInfoArgs); + REGISTER_FST_OPERATION(InitArcIteratorClass, Arc, + InitArcIteratorClassArgs); + REGISTER_FST_OPERATION(InitEncodeMapperClass, Arc, + InitEncodeMapperClassArgs); + REGISTER_FST_OPERATION(InitMutableArcIteratorClass, Arc, + InitMutableArcIteratorClassArgs); + REGISTER_FST_OPERATION(InitStateIteratorClass, Arc, + InitStateIteratorClassArgs); + } + + void RegisterBatch2() { + REGISTER_FST_OPERATION(Intersect, Arc, IntersectArgs); + REGISTER_FST_OPERATION(Invert, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Map, Arc, MapArgs); + REGISTER_FST_OPERATION(Minimize, Arc, MinimizeArgs); + REGISTER_FST_OPERATION(PrintFst, Arc, FstPrinterArgs); + REGISTER_FST_OPERATION(Project, Arc, ProjectArgs); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs1); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs2); + REGISTER_FST_OPERATION(Push, Arc, PushArgs1); + REGISTER_FST_OPERATION(Push, Arc, PushArgs2); + REGISTER_FST_OPERATION(RandEquivalent, Arc, RandEquivalentArgs); + REGISTER_FST_OPERATION(RandGen, Arc, RandGenArgs); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs1); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs2); + REGISTER_FST_OPERATION(Replace, Arc, ReplaceArgs); + REGISTER_FST_OPERATION(Reverse, Arc, ReverseArgs); + REGISTER_FST_OPERATION(Reweight, Arc, ReweightArgs); + REGISTER_FST_OPERATION(RmEpsilon, Arc, RmEpsilonArgs); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs1); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs2); + REGISTER_FST_OPERATION(ShortestPath, Arc, ShortestPathArgs); + REGISTER_FST_OPERATION(Synchronize, Arc, SynchronizeArgs); + REGISTER_FST_OPERATION(TopSort, Arc, TopSortArgs); + REGISTER_FST_OPERATION(Union, Arc, UnionArgs); + REGISTER_FST_OPERATION(Verify, Arc, VerifyArgs); + } +}; + +} // namespace script +} // namespace fst + +#define REGISTER_FST_OPERATIONS(Arc) \ + AllFstOperationsRegisterer register_all_fst_operations##Arc; + +#endif // FST_SCRIPT_FSTSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/script/getters.h b/projects/llm_framework/include/fst/script/getters.h new file mode 100644 index 00000000..5cd727e8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/getters.h @@ -0,0 +1,76 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Getters for converting command-line arguments into the appropriate enums +// or bitmasks, with the simplest ones defined as inline. + +#ifndef FST_SCRIPT_GETTERS_H_ +#define FST_SCRIPT_GETTERS_H_ + +#include + +#include // For ComposeFilter. +#include // For DeterminizeType. +#include // For kEncodeLabels (etc.). +#include // For EpsNormalizeType. +#include // For ProjectType. +#include // For kPushWeights (etc.). +#include // For QueueType. +#include // For ClosureType. +#include // For ArcSortType. +#include // For MapType. +#include // For RandArcSelection. + +#include + +namespace fst { +namespace script { + +bool GetArcSortType(const string &str, ArcSortType *sort_type); + +inline ClosureType GetClosureType(bool closure_plus) { + return closure_plus ? CLOSURE_PLUS : CLOSURE_STAR; +} + +bool GetComposeFilter(const string &str, ComposeFilter *compose_filter); + +bool GetDeterminizeType(const string &str, DeterminizeType *det_type); + +inline uint32 GetEncodeFlags(bool encode_labels, bool encode_weights) { + return (encode_labels ? kEncodeLabels : 0) | + (encode_weights ? kEncodeWeights : 0); +} + +inline EpsNormalizeType GetEpsNormalizeType(bool eps_norm_output) { + return eps_norm_output ? EPS_NORM_OUTPUT : EPS_NORM_INPUT; +} + +bool GetMapType(const string &str, MapType *map_type); + +inline ProjectType GetProjectType(bool project_output) { + return project_output ? PROJECT_OUTPUT : PROJECT_INPUT; +} + +inline uint32 GetPushFlags(bool push_weights, bool push_labels, + bool remove_total_weight, bool remove_common_affix) { + return ((push_weights ? kPushWeights : 0) | + (push_labels ? kPushLabels : 0) | + (remove_total_weight ? kPushRemoveTotalWeight : 0) | + (remove_common_affix ? kPushRemoveCommonAffix : 0)); +} + +bool GetQueueType(const string &str, QueueType *queue_type); + +bool GetRandArcSelection(const string &str, RandArcSelection *ras); + +bool GetReplaceLabelType(const string &str, bool epsilon_on_replace, + ReplaceLabelType *rlt); + +inline ReweightType GetReweightType(bool to_final) { + return to_final ? REWEIGHT_TO_FINAL : REWEIGHT_TO_INITIAL; +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/script/info-impl.h b/projects/llm_framework/include/fst/script/info-impl.h new file mode 100644 index 00000000..e8956498 --- /dev/null +++ b/projects/llm_framework/include/fst/script/info-impl.h @@ -0,0 +1,314 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to compute various information about FSTs, a helper class for +// fstinfo.cc. + +#ifndef FST_SCRIPT_INFO_IMPL_H_ +#define FST_SCRIPT_INFO_IMPL_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Compute various information about FSTs, helper class for fstinfo.cc. +// WARNING: Stand-alone use of this class is not recommended, most code +// should call directly the relevant library functions: Fst::NumStates, +// Fst::NumArcs, TestProperties, etc. +class FstInfo { + public: + FstInfo() {} + + // When info_type is "short" (or "auto" and not an ExpandedFst) then only + // minimal info is computed and can be requested. + template + FstInfo(const Fst &fst, bool test_properties, + const string &arc_filter_type = "any", + const string &info_type = "auto", bool verify = true) + : fst_type_(fst.Type()), + input_symbols_(fst.InputSymbols() ? fst.InputSymbols()->Name() + : "none"), + output_symbols_(fst.OutputSymbols() ? fst.OutputSymbols()->Name() + : "none"), + nstates_(0), + narcs_(0), + start_(kNoStateId), + nfinal_(0), + nepsilons_(0), + niepsilons_(0), + noepsilons_(0), + ilabel_mult_(0.0), + olabel_mult_(0.0), + naccess_(0), + ncoaccess_(0), + nconnect_(0), + ncc_(0), + nscc_(0), + input_match_type_(MATCH_NONE), + output_match_type_(MATCH_NONE), + input_lookahead_(false), + output_lookahead_(false), + properties_(0), + arc_filter_type_(arc_filter_type), + long_info_(true), + arc_type_(Arc::Type()) { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (info_type == "long") { + long_info_ = true; + } else if (info_type == "short") { + long_info_ = false; + } else if (info_type == "auto") { + long_info_ = fst.Properties(kExpanded, false); + } else { + FSTERROR() << "Bad info type: " << info_type; + return; + } + if (!long_info_) return; + // If the FST is not sane, we return. + if (verify && !Verify(fst)) { + FSTERROR() << "FstInfo: Verify: FST not well-formed"; + return; + } + start_ = fst.Start(); + properties_ = fst.Properties(kFstProperties, test_properties); + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + const auto s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) ++nfinal_; + std::map ilabel_count; + std::map olabel_count; + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + ++narcs_; + if (arc.ilabel == 0 && arc.olabel == 0) ++nepsilons_; + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + ++ilabel_count[arc.ilabel]; + ++olabel_count[arc.olabel]; + } + for (auto it = ilabel_count.begin(); it != ilabel_count.end(); ++it) { + ilabel_mult_ += it->second * it->second; + } + for (auto it = olabel_count.begin(); it != olabel_count.end(); ++it) { + olabel_mult_ += it->second * it->second; + } + } + if (narcs_ > 0) { + ilabel_mult_ /= narcs_; + olabel_mult_ /= narcs_; + } + { + std::vector cc; + CcVisitor cc_visitor(&cc); + FifoQueue fifo_queue; + if (arc_filter_type == "any") { + Visit(fst, &cc_visitor, &fifo_queue); + } else if (arc_filter_type == "epsilon") { + Visit(fst, &cc_visitor, &fifo_queue, EpsilonArcFilter()); + } else if (arc_filter_type == "iepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, InputEpsilonArcFilter()); + } else if (arc_filter_type == "oepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, OutputEpsilonArcFilter()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + for (StateId s = 0; s < cc.size(); ++s) { + if (cc[s] >= ncc_) ncc_ = cc[s] + 1; + } + } + { + std::vector scc; + std::vector access, coaccess; + uint64 props = 0; + SccVisitor scc_visitor(&scc, &access, &coaccess, &props); + if (arc_filter_type == "any") { + DfsVisit(fst, &scc_visitor); + } else if (arc_filter_type == "epsilon") { + DfsVisit(fst, &scc_visitor, EpsilonArcFilter()); + } else if (arc_filter_type == "iepsilon") { + DfsVisit(fst, &scc_visitor, InputEpsilonArcFilter()); + } else if (arc_filter_type == "oepsilon") { + DfsVisit(fst, &scc_visitor, OutputEpsilonArcFilter()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + for (StateId s = 0; s < scc.size(); ++s) { + if (access[s]) ++naccess_; + if (coaccess[s]) ++ncoaccess_; + if (access[s] && coaccess[s]) ++nconnect_; + if (scc[s] >= nscc_) nscc_ = scc[s] + 1; + } + } + LookAheadMatcher> imatcher(fst, MATCH_INPUT); + input_match_type_ = imatcher.Type(test_properties); + input_lookahead_ = imatcher.Flags() & kInputLookAheadMatcher; + LookAheadMatcher> omatcher(fst, MATCH_OUTPUT); + output_match_type_ = omatcher.Type(test_properties); + output_lookahead_ = omatcher.Flags() & kOutputLookAheadMatcher; + } + + // Short info. + + const string &FstType() const { return fst_type_; } + + const string &ArcType() const { return arc_type_; } + + const string &InputSymbols() const { return input_symbols_; } + + const string &OutputSymbols() const { return output_symbols_; } + + bool LongInfo() const { return long_info_; } + + const string &ArcFilterType() const { return arc_filter_type_; } + + // Long info. + + MatchType InputMatchType() const { + CheckLong(); + return input_match_type_; + } + + MatchType OutputMatchType() const { + CheckLong(); + return output_match_type_; + } + + bool InputLookAhead() const { + CheckLong(); + return input_lookahead_; + } + + bool OutputLookAhead() const { + CheckLong(); + return output_lookahead_; + } + + int64 NumStates() const { + CheckLong(); + return nstates_; + } + + size_t NumArcs() const { + CheckLong(); + return narcs_; + } + + int64 Start() const { + CheckLong(); + return start_; + } + + size_t NumFinal() const { + CheckLong(); + return nfinal_; + } + + size_t NumEpsilons() const { + CheckLong(); + return nepsilons_; + } + + size_t NumInputEpsilons() const { + CheckLong(); + return niepsilons_; + } + + size_t NumOutputEpsilons() const { + CheckLong(); + return noepsilons_; + } + + double InputLabelMultiplicity() const { + CheckLong(); + return ilabel_mult_; + } + + double OutputLabelMultiplicity() const { + CheckLong(); + return olabel_mult_; + } + + size_t NumAccessible() const { + CheckLong(); + return naccess_; + } + + size_t NumCoAccessible() const { + CheckLong(); + return ncoaccess_; + } + + size_t NumConnected() const { + CheckLong(); + return nconnect_; + } + + size_t NumCc() const { + CheckLong(); + return ncc_; + } + + size_t NumScc() const { + CheckLong(); + return nscc_; + } + + uint64 Properties() const { + CheckLong(); + return properties_; + } + + private: + void CheckLong() const { + if (!long_info_) + FSTERROR() << "FstInfo: Method only available with long info signature"; + } + + string fst_type_; + string input_symbols_; + string output_symbols_; + int64 nstates_; + size_t narcs_; + int64 start_; + size_t nfinal_; + size_t nepsilons_; + size_t niepsilons_; + size_t noepsilons_; + double ilabel_mult_; + double olabel_mult_; + size_t naccess_; + size_t ncoaccess_; + size_t nconnect_; + size_t ncc_; + size_t nscc_; + MatchType input_match_type_; + MatchType output_match_type_; + bool input_lookahead_; + bool output_lookahead_; + uint64 properties_; + string arc_filter_type_; + bool long_info_; + string arc_type_; +}; + +void PrintFstInfoImpl(const FstInfo &fstinfo, bool pipe = false); + +} // namespace fst + +#endif // FST_SCRIPT_INFO_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/info.h b/projects/llm_framework/include/fst/script/info.h new file mode 100644 index 00000000..039d06d8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/info.h @@ -0,0 +1,50 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_INFO_H_ +#define FST_SCRIPT_INFO_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using InfoArgs = std::tuple; + +template +void PrintFstInfo(InfoArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + const FstInfo fstinfo(fst, std::get<1>(*args), std::get<2>(*args), + std::get<3>(*args), std::get<4>(*args)); + PrintFstInfoImpl(fstinfo, std::get<5>(*args)); + if (std::get<5>(*args)) fst.Write(""); +} + +void PrintFstInfo(const FstClass &f, bool test_properties, + const string &arc_filter, const string &info_type, bool pipe, + bool verify); + +using GetInfoArgs = std::tuple; + +template +void GetFstInfo(GetInfoArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + *(std::get<5>(*args)) = FstInfo(fst, std::get<1>(*args), std::get<2>(*args), + std::get<3>(*args), std::get<4>(*args)); +} + +void GetFstInfo(const FstClass &fst, bool test_properties, + const string &arc_filter, const string &info_type, bool verify, + FstInfo *info); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INFO_H_ diff --git a/projects/llm_framework/include/fst/script/intersect.h b/projects/llm_framework/include/fst/script/intersect.h new file mode 100644 index 00000000..229bd56f --- /dev/null +++ b/projects/llm_framework/include/fst/script/intersect.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_INTERSECT_H_ +#define FST_SCRIPT_INTERSECT_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using IntersectArgs = std::tuple; + +template +void Intersect(IntersectArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Intersect(ifst1, ifst2, ofst, opts); +} + +void Intersect(const FstClass &ifst, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INTERSECT_H_ diff --git a/projects/llm_framework/include/fst/script/invert.h b/projects/llm_framework/include/fst/script/invert.h new file mode 100644 index 00000000..5bc31317 --- /dev/null +++ b/projects/llm_framework/include/fst/script/invert.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_INVERT_H_ +#define FST_SCRIPT_INVERT_H_ + +#include +#include + +namespace fst { +namespace script { + +template +void Invert(MutableFstClass *fst) { + Invert(fst->GetMutableFst()); +} + +void Invert(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INVERT_H_ diff --git a/projects/llm_framework/include/fst/script/isomorphic.h b/projects/llm_framework/include/fst/script/isomorphic.h new file mode 100644 index 00000000..94ea77f9 --- /dev/null +++ b/projects/llm_framework/include/fst/script/isomorphic.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ISOMORPHIC_H_ +#define FST_SCRIPT_ISOMORPHIC_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using IsomorphicInnerArgs = std::tuple; + +using IsomorphicArgs = WithReturnValue; + +template +void Isomorphic(IsomorphicArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Isomorphic(fst1, fst2, std::get<2>(args->args)); +} + +bool Isomorphic(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/script/map.h b/projects/llm_framework/include/fst/script/map.h new file mode 100644 index 00000000..158d98aa --- /dev/null +++ b/projects/llm_framework/include/fst/script/map.h @@ -0,0 +1,158 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_MAP_H_ +#define FST_SCRIPT_MAP_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +template +Fst *ArcMap(const Fst &fst, + const M &mapper) { + using ToArc = typename M::ToArc; + auto *ofst = new VectorFst; + ArcMap(fst, ofst, mapper); + return ofst; +} + +template +Fst *StateMap(const Fst &fst, + const M &mapper) { + using ToArc = typename M::ToArc; + auto *ofst = new VectorFst; + StateMap(fst, ofst, mapper); + return ofst; +} + +enum MapType { + ARC_SUM_MAPPER, + ARC_UNIQUE_MAPPER, + IDENTITY_MAPPER, + INPUT_EPSILON_MAPPER, + INVERT_MAPPER, + OUTPUT_EPSILON_MAPPER, + PLUS_MAPPER, + POWER_MAPPER, + QUANTIZE_MAPPER, + RMWEIGHT_MAPPER, + SUPERFINAL_MAPPER, + TIMES_MAPPER, + TO_LOG_MAPPER, + TO_LOG64_MAPPER, + TO_STD_MAPPER +}; + +using MapInnerArgs = + std::tuple; + +using MapArgs = WithReturnValue; + +template +void Map(MapArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(args->args).GetFst()); + const auto map_type = std::get<1>(args->args); + switch (map_type) { + case ARC_SUM_MAPPER: { + std::unique_ptr> ofst(StateMap(ifst, ArcSumMapper(ifst))); + args->retval = new FstClass(*ofst); + return; + } + case ARC_UNIQUE_MAPPER: { + std::unique_ptr> ofst( + StateMap(ifst, ArcUniqueMapper(ifst))); + args->retval = new FstClass(*ofst); + return; + } + case IDENTITY_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, IdentityArcMapper())); + args->retval = new FstClass(*ofst); + return; + } + case INPUT_EPSILON_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, InputEpsilonMapper())); + args->retval = new FstClass(*ofst); + return; + } + case INVERT_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, InvertWeightMapper())); + args->retval = new FstClass(*ofst); + return; + } + case OUTPUT_EPSILON_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, OutputEpsilonMapper())); + args->retval = new FstClass(*ofst); + return; + } + case PLUS_MAPPER: { + const auto weight = *(std::get<4>(args->args).GetWeight()); + std::unique_ptr> ofst(ArcMap(ifst, PlusMapper(weight))); + args->retval = new FstClass(*ofst); + return; + } + case POWER_MAPPER: { + const auto power = std::get<3>(args->args); + std::unique_ptr> ofst(ArcMap(ifst, PowerMapper(power))); + args->retval = new FstClass(*ofst); + return; + } + case QUANTIZE_MAPPER: { + const auto delta = std::get<2>(args->args); + std::unique_ptr> ofst(ArcMap(ifst, QuantizeMapper(delta))); + args->retval = new FstClass(*ofst); + return; + } + case RMWEIGHT_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, RmWeightMapper())); + args->retval = new FstClass(*ofst); + return; + } + case SUPERFINAL_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, SuperFinalMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TIMES_MAPPER: { + const auto weight = *(std::get<4>(args->args).GetWeight()); + std::unique_ptr> ofst(ArcMap(ifst, TimesMapper(weight))); + args->retval = new FstClass(*ofst); + return; + } + case TO_LOG_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TO_LOG64_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TO_STD_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + } +} + +FstClass *Map(const FstClass &ifst, MapType map_type, float delta, double power, + const WeightClass &weight); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_MAP_H_ diff --git a/projects/llm_framework/include/fst/script/minimize.h b/projects/llm_framework/include/fst/script/minimize.h new file mode 100644 index 00000000..773e8ef1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/minimize.h @@ -0,0 +1,33 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_MINIMIZE_H_ +#define FST_SCRIPT_MINIMIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using MinimizeArgs = std::tuple; + +template +void Minimize(MinimizeArgs *args) { + MutableFst *ofst1 = std::get<0>(*args)->GetMutableFst(); + MutableFst *ofst2 = (std::get<1>(*args) ? + std::get<1>(*args)->GetMutableFst() : + nullptr); + Minimize(ofst1, ofst2, std::get<2>(*args), std::get<3>(*args)); +} + +void Minimize(MutableFstClass *ofst1, MutableFstClass *ofst2 = nullptr, + float delta = kShortestDelta, bool allow_nondet = false); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_MINIMIZE_H_ diff --git a/projects/llm_framework/include/fst/script/print-impl.h b/projects/llm_framework/include/fst/script/print-impl.h new file mode 100644 index 00000000..539c6d8f --- /dev/null +++ b/projects/llm_framework/include/fst/script/print-impl.h @@ -0,0 +1,132 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Stand-alone class to print out binary FSTs in the AT&T format, a helper +// class for fstprint.cc. + +#ifndef FST_SCRIPT_PRINT_IMPL_H_ +#define FST_SCRIPT_PRINT_IMPL_H_ + +#include +#include +#include + +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Print a binary FST in textual format (helper class for fstprint.cc). +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. +template +class FstPrinter { + public: + using StateId = typename Arc::StateId; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + FstPrinter(const Fst &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + bool show_weight_one, const string &field_separator, + const string &missing_symbol = "") + : fst_(fst), + isyms_(isyms), + osyms_(osyms), + ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), + ostrm_(nullptr), + show_weight_one_(show_weight_one), + sep_(field_separator), + missing_symbol_(missing_symbol) {} + + // Prints FST to an output stream. + void Print(std::ostream *ostrm, const string &dest) { + ostrm_ = ostrm; + dest_ = dest; + const auto start = fst_.Start(); + if (start == kNoStateId) return; + // Initial state first. + PrintState(start); + for (StateIterator> siter(fst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s != start) PrintState(s); + } + } + + private: + void PrintId(StateId id, const SymbolTable *syms, const char *name) const { + if (syms) { + string symbol = syms->Find(id); + if (symbol.empty()) { + if (missing_symbol_.empty()) { + FSTERROR() << "FstPrinter: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } else { + symbol = missing_symbol_; + } + } + *ostrm_ << symbol; + } else { + *ostrm_ << id; + } + } + + void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); } + + void PrintILabel(Label l) const { PrintId(l, isyms_, "arc input label"); } + + void PrintOLabel(Label l) const { PrintId(l, osyms_, "arc output label"); } + + void PrintState(StateId s) const { + bool output = false; + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + PrintStateId(s); + *ostrm_ << sep_; + PrintStateId(arc.nextstate); + *ostrm_ << sep_; + PrintILabel(arc.ilabel); + if (!accep_) { + *ostrm_ << sep_; + PrintOLabel(arc.olabel); + } + if (show_weight_one_ || arc.weight != Weight::One()) + *ostrm_ << sep_ << arc.weight; + *ostrm_ << "\n"; + output = true; + } + const auto weight = fst_.Final(s); + if (weight != Weight::Zero() || !output) { + PrintStateId(s); + if (show_weight_one_ || weight != Weight::One()) { + *ostrm_ << sep_ << weight; + } + *ostrm_ << "\n"; + } + } + + const Fst &fst_; + const SymbolTable *isyms_; // ilabel symbol table. + const SymbolTable *osyms_; // olabel symbol table. + const SymbolTable *ssyms_; // slabel symbol table. + bool accep_; // Print as acceptor when possible? + std::ostream *ostrm_; // Text FST destination. + string dest_; // Text FST destination name. + bool show_weight_one_; // Print weights equal to Weight::One()? + string sep_; // Separator character between fields. + string missing_symbol_; // Symbol to print when lookup fails (default + // "" means raise error). + // + FstPrinter(const FstPrinter &) = delete; + FstPrinter &operator=(const FstPrinter &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_PRINT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/print.h b/projects/llm_framework/include/fst/script/print.h new file mode 100644 index 00000000..687606b3 --- /dev/null +++ b/projects/llm_framework/include/fst/script/print.h @@ -0,0 +1,79 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PRINT_H_ +#define FST_SCRIPT_PRINT_H_ + +#include + +#include +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FstPrinterArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accept; + const bool show_weight_one; + std::ostream *ostrm; + const string &dest; + const string &sep; // NOLINT + const string &missing_symbol; + + FstPrinterArgs(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, + bool accept, bool show_weight_one, std::ostream *ostrm, + const string &dest, const string &sep, + const string &missing_sym = "") + : fst(fst), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accept(accept), + show_weight_one(show_weight_one), + ostrm(ostrm), + dest(dest), + sep(sep), + missing_symbol(missing_sym) {} +}; + +template +void PrintFst(FstPrinterArgs *args) { + const Fst &fst = *(args->fst.GetFst()); + FstPrinter fstprinter(fst, args->isyms, args->osyms, args->ssyms, + args->accept, args->show_weight_one, args->sep, + args->missing_symbol); + fstprinter.Print(args->ostrm, args->dest); +} + +void PrintFst(const FstClass &fst, std::ostream &ostrm, const string &dest, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accept, bool show_weight_one, + const string &missing_sym = ""); + +// The same, but with more sensible defaults. +template +void PrintFst(const Fst &fst, std::ostream &ostrm, const string &dest = "", + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + const SymbolTable *ssyms = nullptr) { + const string sep = FLAGS_fst_field_separator.substr(0, 1); + FstPrinter fstprinter(fst, isyms, osyms, ssyms, true, true, sep); + fstprinter.Print(&ostrm, dest); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PRINT_H_ diff --git a/projects/llm_framework/include/fst/script/project.h b/projects/llm_framework/include/fst/script/project.h new file mode 100644 index 00000000..13edeb1d --- /dev/null +++ b/projects/llm_framework/include/fst/script/project.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PROJECT_H_ +#define FST_SCRIPT_PROJECT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ProjectArgs = std::pair; + +template +void Project(ProjectArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Project(fst, std::get<1>(*args)); +} + +void Project(MutableFstClass *fst, ProjectType project_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PROJECT_H_ diff --git a/projects/llm_framework/include/fst/script/prune.h b/projects/llm_framework/include/fst/script/prune.h new file mode 100644 index 00000000..ed10b540 --- /dev/null +++ b/projects/llm_framework/include/fst/script/prune.h @@ -0,0 +1,51 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PRUNE_H_ +#define FST_SCRIPT_PRUNE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using PruneArgs1 = std::tuple; + +template +void Prune(PruneArgs1 *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto weight_threshold = *(std::get<2>(*args).GetWeight()); + Prune(ifst, ofst, weight_threshold, std::get<3>(*args), std::get<4>(*args)); +} + +using PruneArgs2 = std::tuple; + +template +void Prune(PruneArgs2 *args) { + using Weight = typename Arc::Weight; + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const auto weight_threshold = *(std::get<1>(*args).GetWeight()); + Prune(fst, weight_threshold, std::get<2>(*args), std::get<3>(*args)); +} + +void Prune(const FstClass &ifst, MutableFstClass *ofst, + const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + float delta = kDelta); + +void Prune(MutableFstClass *fst, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PRUNE_H_ diff --git a/projects/llm_framework/include/fst/script/push.h b/projects/llm_framework/include/fst/script/push.h new file mode 100644 index 00000000..018cd8f8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/push.h @@ -0,0 +1,53 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PUSH_H_ +#define FST_SCRIPT_PUSH_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using PushArgs1 = std::tuple; + +template +void Push(PushArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Push(fst, std::get<1>(*args), std::get<2>(*args), std::get<3>(*args)); +} + +using PushArgs2 = std::tuple; + +template +void Push(PushArgs2 *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + switch (std::get<3>(*args)) { + case REWEIGHT_TO_FINAL: { + Push(ifst, ofst, std::get<2>(*args), + std::get<4>(*args)); + return; + } + case REWEIGHT_TO_INITIAL: { + Push(ifst, ofst, std::get<2>(*args), + std::get<4>(*args)); + return; + } + } +} + +void Push(MutableFstClass *fst, ReweightType rew_type, float delta = kDelta, + bool remove_total_weight = false); + +void Push(const FstClass &ifst, MutableFstClass *ofst, uint32 flags, + ReweightType rew_type, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PUSH_H_ diff --git a/projects/llm_framework/include/fst/script/randequivalent.h b/projects/llm_framework/include/fst/script/randequivalent.h new file mode 100644 index 00000000..945f8a06 --- /dev/null +++ b/projects/llm_framework/include/fst/script/randequivalent.h @@ -0,0 +1,67 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RANDEQUIVALENT_H_ +#define FST_SCRIPT_RANDEQUIVALENT_H_ + +#include +#include + +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +using RandEquivalentInnerArgs = std::tuple &>; + +using RandEquivalentArgs = WithReturnValue; + +template +void RandEquivalent(RandEquivalentArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + const auto seed = std::get<4>(args->args); + const auto &opts = std::get<5>(args->args); + switch (opts.selector) { + case UNIFORM_ARC_SELECTOR: { + const UniformArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + case FAST_LOG_PROB_ARC_SELECTOR: { + const FastLogProbArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + case LOG_PROB_ARC_SELECTOR: { + const LogProbArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + } +} + +bool RandEquivalent(const FstClass &fst1, const FstClass &fst2, int32 npath = 1, + float delta = kDelta, time_t seed = time(nullptr), + const RandGenOptions &opts = + RandGenOptions(UNIFORM_ARC_SELECTOR)); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RANDEQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/script/randgen.h b/projects/llm_framework/include/fst/script/randgen.h new file mode 100644 index 00000000..5ce79d01 --- /dev/null +++ b/projects/llm_framework/include/fst/script/randgen.h @@ -0,0 +1,63 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RANDGEN_H_ +#define FST_SCRIPT_RANDGEN_H_ + +#include + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using RandGenArgs = std::tuple &>; + +template +void RandGen(RandGenArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const time_t seed = std::get<2>(*args); + const auto &opts = std::get<3>(*args); + switch (opts.selector) { + case UNIFORM_ARC_SELECTOR: { + const UniformArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + case FAST_LOG_PROB_ARC_SELECTOR: { + const FastLogProbArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + case LOG_PROB_ARC_SELECTOR: { + const LogProbArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + } +} + +void RandGen(const FstClass &ifst, MutableFstClass *ofst, + time_t seed = time(nullptr), + const RandGenOptions &opts = + RandGenOptions(UNIFORM_ARC_SELECTOR)); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RANDGEN_H_ diff --git a/projects/llm_framework/include/fst/script/register.h b/projects/llm_framework/include/fst/script/register.h new file mode 100644 index 00000000..d66e7ade --- /dev/null +++ b/projects/llm_framework/include/fst/script/register.h @@ -0,0 +1,99 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REGISTER_H_ +#define FST_SCRIPT_REGISTER_H_ + +#include +#include + +#include +#include +#include + +// Holds methods and classes responsible for maintaining +// the register for FstClass arc types. + +namespace fst { +namespace script { + +// Registers for reading and converting various kinds of FST classes. + +// This class definition is to avoid a nested class definition inside the +// IORegistration struct. + +template +struct FstClassRegEntry { + Reader reader; + Creator creator; + Converter converter; + + FstClassRegEntry(Reader r, Creator cr, Converter co) + : reader(r), creator(cr), converter(co) {} + + FstClassRegEntry() + : reader(nullptr), creator(nullptr), converter(nullptr) {} +}; + +template +class FstClassIORegister + : public GenericRegister, + FstClassIORegister> { + public: + Reader GetReader(const string &arc_type) const { + return this->GetEntry(arc_type).reader; + } + + Creator GetCreator(const string &arc_type) const { + return this->GetEntry(arc_type).creator; + } + + Converter GetConverter(const string &arc_type) const { + return this->GetEntry(arc_type).converter; + } + + protected: + string ConvertKeyToSoFilename(const string &key) const final { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-arc.so"; + } +}; + +// Struct containing everything needed to register a particular type +// of FST class (e.g., a plain FstClass, or a MutableFstClass, etc.). +template +struct IORegistration { + using Reader = FstClassType *(*)(std::istream &stream, + const FstReadOptions &opts); + + using Creator = FstClassImplBase *(*)(); + + using Converter = FstClassImplBase *(*)(const FstClass &other); + + using Entry = FstClassRegEntry; + + // FST class Register. + using Register = FstClassIORegister; + + // FST class Register-er. + using Registerer = + GenericRegisterer>; +}; + +#define REGISTER_FST_CLASS(Class, Arc) \ + static IORegistration::Registerer Class##_##Arc##_registerer( \ + Arc::Type(), \ + IORegistration::Entry(Class::Read, Class::Create, \ + Class::Convert)) + +#define REGISTER_FST_CLASSES(Arc) \ + REGISTER_FST_CLASS(FstClass, Arc); \ + REGISTER_FST_CLASS(MutableFstClass, Arc); \ + REGISTER_FST_CLASS(VectorFstClass, Arc); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REGISTER_H_ diff --git a/projects/llm_framework/include/fst/script/relabel.h b/projects/llm_framework/include/fst/script/relabel.h new file mode 100644 index 00000000..74443490 --- /dev/null +++ b/projects/llm_framework/include/fst/script/relabel.h @@ -0,0 +1,64 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RELABEL_H_ +#define FST_SCRIPT_RELABEL_H_ + +#include +#include +#include +#include + +#include +#include + +namespace fst { +namespace script { + +using RelabelArgs1 = std::tuple; + +template +void Relabel(RelabelArgs1 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + Relabel(ofst, std::get<1>(*args), std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args), std::get<5>(*args), std::get<6>(*args), + std::get<7>(*args), std::get<8>(*args)); +} + +using LabelPair = std::pair; + +using RelabelArgs2 = std::tuple &, + const std::vector &>; + +template +void Relabel(RelabelArgs2 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + using LabelPair = std::pair; + // In case the MutableFstClass::Label is not the same as Arc::Label, + // make a copy. + std::vector typed_ipairs(std::get<1>(*args).size()); + std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), + typed_ipairs.begin()); + std::vector typed_opairs(std::get<2>(*args).size()); + std::copy(std::get<2>(*args).begin(), std::get<2>(*args).end(), + typed_opairs.begin()); + Relabel(ofst, typed_ipairs, typed_opairs); +} + +void Relabel(MutableFstClass *ofst, + const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, + const string &unknown_isymbol, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + const string &unknown_osymbol, bool attach_new_osymbols); + +void Relabel(MutableFstClass *ofst, const std::vector &ipairs, + const std::vector &opairs); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RELABEL_H_ diff --git a/projects/llm_framework/include/fst/script/replace.h b/projects/llm_framework/include/fst/script/replace.h new file mode 100644 index 00000000..926ece24 --- /dev/null +++ b/projects/llm_framework/include/fst/script/replace.h @@ -0,0 +1,72 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REPLACE_H_ +#define FST_SCRIPT_REPLACE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { +namespace script { + +struct ReplaceOptions { + const int64 root; // Root rule for expansion. + const ReplaceLabelType call_label_type; // How to label call arc. + const ReplaceLabelType return_label_type; // How to label return arc. + const int64 return_label; // Specifies return arc label. + + explicit ReplaceOptions(int64 root, + ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT, + ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER, + int64 return_label = 0) + : root(root), + call_label_type(call_label_type), + return_label_type(return_label_type), + return_label(return_label) {} +}; + +using LabelFstClassPair = std::pair; + +using ReplaceArgs = std::tuple &, + MutableFstClass *, const ReplaceOptions &>; + +template +void Replace(ReplaceArgs *args) { + using LabelFstPair = std::pair *>; + // Now that we know the arc type, we construct a vector of + // std::pair that the real Replace will use. + const auto &untyped_pairs = std::get<0>(*args); + std::vector typed_pairs; + typed_pairs.reserve(untyped_pairs.size()); + for (const auto &untyped_pair : untyped_pairs) { + typed_pairs.emplace_back(untyped_pair.first, // Converts label. + untyped_pair.second->GetFst()); + } + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + ReplaceFstOptions typed_opts(opts.root, opts.call_label_type, + opts.return_label_type, opts.return_label); + ReplaceFst rfst(typed_pairs, typed_opts); + // Checks for cyclic dependencies before attempting expansion. + if (rfst.CyclicDependencies()) { + FSTERROR() << "Replace: Cyclic dependencies detected; cannot expand"; + ofst->SetProperties(kError, kError); + return; + } + typed_opts.gc = true; // Caching options to speed up batch copy. + typed_opts.gc_limit = 0; + *ofst = rfst; +} + +void Replace(const std::vector &pairs, + MutableFstClass *ofst, const ReplaceOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REPLACE_H_ diff --git a/projects/llm_framework/include/fst/script/reverse.h b/projects/llm_framework/include/fst/script/reverse.h new file mode 100644 index 00000000..badd96b5 --- /dev/null +++ b/projects/llm_framework/include/fst/script/reverse.h @@ -0,0 +1,30 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REVERSE_H_ +#define FST_SCRIPT_REVERSE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ReverseArgs = std::tuple; + +template +void Reverse(ReverseArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Reverse(ifst, ofst, std::get<2>(*args)); +} + +void Reverse(const FstClass &ifst, MutableFstClass *ofst, + bool require_superinitial = true); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REVERSE_H_ diff --git a/projects/llm_framework/include/fst/script/reweight.h b/projects/llm_framework/include/fst/script/reweight.h new file mode 100644 index 00000000..3893ad85 --- /dev/null +++ b/projects/llm_framework/include/fst/script/reweight.h @@ -0,0 +1,37 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REWEIGHT_H_ +#define FST_SCRIPT_REWEIGHT_H_ + +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +using ReweightArgs = std::tuple &, ReweightType>; + +template +void Reweight(ReweightArgs *args) { + using Weight = typename Arc::Weight; + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const std::vector &potentials = std::get<1>(*args); + std::vector typed_potentials; + internal::CopyWeights(potentials, &typed_potentials); + Reweight(fst, typed_potentials, std::get<2>(*args)); +} + +void Reweight(MutableFstClass *fst, const std::vector &potentials, + ReweightType reweight_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REWEIGHT_H_ diff --git a/projects/llm_framework/include/fst/script/rmepsilon.h b/projects/llm_framework/include/fst/script/rmepsilon.h new file mode 100644 index 00000000..42986c85 --- /dev/null +++ b/projects/llm_framework/include/fst/script/rmepsilon.h @@ -0,0 +1,109 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RMEPSILON_H_ +#define FST_SCRIPT_RMEPSILON_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +struct RmEpsilonOptions : public ShortestDistanceOptions { + const bool connect; + const WeightClass &weight_threshold; + const int64 state_threshold; + + RmEpsilonOptions(QueueType queue_type, bool connect, + const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, float delta = kDelta) + : ShortestDistanceOptions(queue_type, EPSILON_ARC_FILTER, kNoStateId, + delta), + connect(connect), + weight_threshold(weight_threshold), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Code to implement switching on queue types. + +template +void RmEpsilon(MutableFst *fst, + std::vector *distance, + const RmEpsilonOptions &opts, Queue *queue) { + using Weight = typename Arc::Weight; + const fst::RmEpsilonOptions ropts( + queue, opts.delta, opts.connect, + *opts.weight_threshold.GetWeight(), opts.state_threshold); + RmEpsilon(fst, distance, ropts); +} + +template +void RmEpsilon(MutableFst *fst, const RmEpsilonOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + AutoQueue queue(*fst, &distance, EpsilonArcFilter()); + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case FIFO_QUEUE: { + FifoQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case LIFO_QUEUE: { + LifoQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case SHORTEST_FIRST_QUEUE: { + NaturalShortestFirstQueue queue(distance); + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case STATE_ORDER_QUEUE: { + StateOrderQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case TOP_ORDER_QUEUE: { + TopOrderQueue queue(*fst, EpsilonArcFilter()); + internal::RmEpsilon(fst, &distance, opts, &queue); + return; + } + default: { + FSTERROR() << "RmEpsilon: Unknown queue type: " << opts.queue_type; + fst->SetProperties(kError, kError); + return; + } + } +} + +} // namespace internal + +using RmEpsilonArgs = std::pair; + +template +void RmEpsilon(RmEpsilonArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const auto &opts = std::get<1>(*args); + internal::RmEpsilon(fst, opts); +} + +void RmEpsilon(MutableFstClass *fst, const RmEpsilonOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RMEPSILON_H_ diff --git a/projects/llm_framework/include/fst/script/script-impl.h b/projects/llm_framework/include/fst/script/script-impl.h new file mode 100644 index 00000000..33c2853a --- /dev/null +++ b/projects/llm_framework/include/fst/script/script-impl.h @@ -0,0 +1,211 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// This file defines the registration mechanism for new operations. +// These operations are designed to enable scripts to work with FST classes +// at a high level. +// +// If you have a new arc type and want these operations to work with FSTs +// with that arc type, see below for the registration steps +// you must take. +// +// These methods are only recommended for use in high-level scripting +// applications. Most users should use the lower-level templated versions +// corresponding to these. +// +// If you have a new arc type you'd like these operations to work with, +// use the REGISTER_FST_OPERATIONS macro defined in fstscript.h. +// +// If you have a custom operation you'd like to define, you need four +// components. In the following, assume you want to create a new operation +// with the signature +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst); +// +// You need: +// +// 1) A way to bundle the args that your new Foo operation will take, as +// a single struct. The template structs in arg-packs.h provide a handy +// way to do this. In Foo's case, that might look like this: +// +// using FooArgs = std::pair; +// +// Note: this package of args is going to be passed by non-const pointer. +// +// 2) A function template that is able to perform Foo, given the args and +// arc type. Yours might look like this: +// +// template +// void Foo(FooArgs *args) { +// // Pulls out the actual, arc-templated FSTs. +// const Fst &ifst = std::get<0>(*args).GetFst(); +// MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); +// // Actually perform Foo on ifst and ofst. +// } +// +// 3) a client-facing function for your operation. This would look like +// the following: +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst) { +// // Check that the arc types of the FSTs match +// if (!ArcTypesMatch(ifst, *ofst, "Foo")) return; +// // package the args +// FooArgs args(ifst, ofst); +// // Finally, call the operation +// Apply>("Foo", ifst->ArcType(), &args); +// } +// +// The Apply<> function template takes care of the link between 2 and 3, +// provided you also have: +// +// 4) A registration for your new operation, on the arc types you care about. +// This can be provided easily by the REGISTER_FST_OPERATION macro in +// operations.h: +// +// REGISTER_FST_OPERATION(Foo, StdArc, FooArgs); +// REGISTER_FST_OPERATION(Foo, MyArc, FooArgs); +// // .. etc +// +// +// That's it! Now when you call Foo(const FstClass &, MutableFstClass *), +// it dispatches (in #3) via the Apply<> function to the correct +// instantiation of the template function in #2. +// + +#ifndef FST_SCRIPT_SCRIPT_IMPL_H_ +#define FST_SCRIPT_SCRIPT_IMPL_H_ + +// This file contains general-purpose templates which are used in the +// implementation of the operations. + +#include +#include + +#include +#include + +#include + +namespace fst { +namespace script { + +enum RandArcSelection { + UNIFORM_ARC_SELECTOR, + LOG_PROB_ARC_SELECTOR, + FAST_LOG_PROB_ARC_SELECTOR +}; + +// A generic register for operations with various kinds of signatures. +// Needed since every function signature requires a new registration class. +// The std::pair is understood to be the operation name and arc +// type; subclasses (or typedefs) need only provide the operation signature. + +template +class GenericOperationRegister + : public GenericRegister, OperationSignature, + GenericOperationRegister> { + public: + void RegisterOperation(const string &operation_name, const string &arc_type, + OperationSignature op) { + this->SetEntry(std::make_pair(operation_name, arc_type), op); + } + + OperationSignature GetOperation(const string &operation_name, + const string &arc_type) { + return this->GetEntry(std::make_pair(operation_name, arc_type)); + } + + protected: + string ConvertKeyToSoFilename( + const std::pair &key) const final { + // Uses the old-style FST for now. + string legal_type(key.second); // The arc type. + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-arc.so"; + } +}; + +// Operation package: everything you need to register a new type of operation. +// The ArgPack should be the type that's passed into each wrapped function; +// for instance, it might be a struct containing all the args. It's always +// passed by pointer, so const members should be used to enforce constness where +// it's needed. Return values should be implemented as a member of ArgPack as +// well. + +template +struct Operation { + using ArgPack = Args; + + using OpType = void (*)(ArgPack *args); + + // The register (hash) type. + using Register = GenericOperationRegister; + + // The register-er type + using Registerer = GenericRegisterer; +}; + +// Macro for registering new types of operations. + +#define REGISTER_FST_OPERATION(Op, Arc, ArgPack) \ + static fst::script::Operation::Registerer \ + arc_dispatched_operation_##ArgPack##Op##Arc##_registerer \ + (std::make_pair(#Op, Arc::Type()), Op) + +// Template function to apply an operation by name. + +template +void Apply(const string &op_name, const string &arc_type, + typename OpReg::ArgPack *args) { + const auto op = OpReg::Register::GetRegister()->GetOperation(op_name, + arc_type); + if (!op) { + FSTERROR() << "No operation found for " << op_name << " on " + << "arc type " << arc_type; + return; + } + op(args); +} + +namespace internal { + +// Helper that logs to ERROR if the arc types of m and n don't match, +// assuming that both m and n implement .ArcType(). The op_name argument is +// used to construct the error message. +template +bool ArcTypesMatch(const M &m, const N &n, const string &op_name) { + if (m.ArcType() != n.ArcType()) { + FSTERROR() << "Arguments with non-matching arc types passed to " + << op_name << ":\t" << m.ArcType() << " and " << n.ArcType(); + return false; + } + return true; +} + +// From untyped to typed weights. +template +void CopyWeights(const std::vector &weights, + std::vector *typed_weights) { + typed_weights->clear(); + typed_weights->reserve(weights.size()); + for (const auto &weight : weights) { + typed_weights->push_back(*weight.GetWeight()); + } +} + +// From typed to untyped weights. +template +void CopyWeights(const std::vector &typed_weights, + std::vector *weights) { + weights->clear(); + weights->reserve(typed_weights.size()); + for (const auto &typed_weight : typed_weights) { + weights->emplace_back(typed_weight); + } +} + +} // namespace internal +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SCRIPT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/shortest-distance.h b/projects/llm_framework/include/fst/script/shortest-distance.h new file mode 100644 index 00000000..a44a6c9b --- /dev/null +++ b/projects/llm_framework/include/fst/script/shortest-distance.h @@ -0,0 +1,214 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_ +#define FST_SCRIPT_SHORTEST_DISTANCE_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +enum ArcFilterType { + ANY_ARC_FILTER, + EPSILON_ARC_FILTER, + INPUT_EPSILON_ARC_FILTER, + OUTPUT_EPSILON_ARC_FILTER +}; + +struct ShortestDistanceOptions { + const QueueType queue_type; + const ArcFilterType arc_filter_type; + const int64 source; + const float delta; + + ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type, + int64 source, float delta) + : queue_type(queue_type), + arc_filter_type(arc_filter_type), + source(source), + delta(delta) {} +}; + +namespace internal { + +// Code to implement switching on queue and arc filter types. + +template +struct QueueConstructor { + using Weight = typename Arc::Weight; + + static Queue *Construct(const Fst &, const std::vector *) { + return new Queue(); + } +}; + +// Specializations to support queues with different constructors. + +template +struct QueueConstructor, ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // template + static AutoQueue *Construct(const Fst &fst, + const std::vector *distance) { + return new AutoQueue(fst, distance, ArcFilter()); + } +}; + +template +struct QueueConstructor< + Arc, NaturalShortestFirstQueue, + ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + static NaturalShortestFirstQueue *Construct( + const Fst &, const std::vector *distance) { + return new NaturalShortestFirstQueue(*distance); + } +}; + +template +struct QueueConstructor, ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + static TopOrderQueue *Construct(const Fst &fst, + const std::vector *) { + return new TopOrderQueue(fst, ArcFilter()); + } +}; + +template +void ShortestDistance(const Fst &fst, + std::vector *distance, + const ShortestDistanceOptions &opts) { + std::unique_ptr queue( + QueueConstructor::Construct(fst, distance)); + const fst::ShortestDistanceOptions sopts( + queue.get(), ArcFilter(), opts.source, opts.delta); + ShortestDistance(fst, distance, sopts); +} + +template +void ShortestDistance(const Fst &fst, + std::vector *distance, + const ShortestDistanceOptions &opts) { + switch (opts.arc_filter_type) { + case ANY_ARC_FILTER: { + ShortestDistance>(fst, distance, opts); + return; + } + case EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, opts); + return; + } + case INPUT_EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, + opts); + return; + } + case OUTPUT_EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, + opts); + return; + } + default: { + FSTERROR() << "ShortestDistance: Unknown arc filter type: " + << opts.arc_filter_type; + distance->clear(); + distance->resize(1, Arc::Weight::NoWeight()); + return; + } + } +} + +} // namespace internal + +using ShortestDistanceArgs1 = + std::tuple *, + const ShortestDistanceOptions &>; + +template +void ShortestDistance(ShortestDistanceArgs1 *args) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + const Fst &fst = *(std::get<0>(*args).GetFst()); + const auto &opts = std::get<2>(*args); + std::vector typed_distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case FIFO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case LIFO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case SHORTEST_FIRST_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + case STATE_ORDER_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + case TOP_ORDER_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + default: { + FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type; + typed_distance.clear(); + typed_distance.resize(1, Arc::Weight::NoWeight()); + break; + } + } + internal::CopyWeights(typed_distance, std::get<1>(*args)); +} + +using ShortestDistanceArgs2 = + std::tuple *, bool, double>; + +template +void ShortestDistance(ShortestDistanceArgs2 *args) { + using Weight = typename Arc::Weight; + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::vector typed_distance; + ShortestDistance(fst, &typed_distance, std::get<2>(*args), + std::get<3>(*args)); + internal::CopyWeights(typed_distance, std::get<1>(*args)); +} + +void ShortestDistance(const FstClass &fst, std::vector *distance, + const ShortestDistanceOptions &opts); + +void ShortestDistance(const FstClass &ifst, std::vector *distance, + bool reverse = false, + double delta = fst::kShortestDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SHORTEST_DISTANCE_H_ diff --git a/projects/llm_framework/include/fst/script/shortest-path.h b/projects/llm_framework/include/fst/script/shortest-path.h new file mode 100644 index 00000000..86bc88da --- /dev/null +++ b/projects/llm_framework/include/fst/script/shortest-path.h @@ -0,0 +1,116 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_SHORTEST_PATH_H_ +#define FST_SCRIPT_SHORTEST_PATH_H_ + +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +// Slightly simplified interface: `has_distance` and `first_path` are disabled. + +struct ShortestPathOptions : public ShortestDistanceOptions { + const int32 nshortest; + const bool unique; + const WeightClass &weight_threshold; + const int64 state_threshold; + + ShortestPathOptions(QueueType queue_type, int32 nshortest, bool unique, + float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId) + : ShortestDistanceOptions(queue_type, ANY_ARC_FILTER, kNoStateId, delta), + nshortest(nshortest), + unique(unique), + weight_threshold(weight_threshold), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Code to implement switching on queue types. + +template +void ShortestPath(const Fst &ifst, MutableFst *ofst, + std::vector *distance, + const ShortestPathOptions &opts) { + using ArcFilter = AnyArcFilter; + using Weight = typename Arc::Weight; + const std::unique_ptr queue( + QueueConstructor::Construct(ifst, distance)); + const fst::ShortestPathOptions sopts( + queue.get(), ArcFilter(), opts.nshortest, opts.unique, + /* has_distance=*/false, opts.delta, /* first_path=*/false, + *opts.weight_threshold.GetWeight(), opts.state_threshold); + ShortestPath(ifst, ofst, distance, sopts); +} + +template +void ShortestPath(const Fst &ifst, MutableFst *ofst, + const ShortestPathOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case FIFO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case LIFO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case SHORTEST_FIRST_QUEUE: { + ShortestPath>(ifst, ofst, + &distance, + opts); + return; + } + case STATE_ORDER_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case TOP_ORDER_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + default: { + FSTERROR() << "ShortestPath: Unknown queue type: " + << opts.queue_type; + ofst->SetProperties(kError, kError); + return; + } + } +} + +} // namespace internal + +using ShortestPathArgs = std::tuple; + +template +void ShortestPath(ShortestPathArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const ShortestPathOptions &opts = std::get<2>(*args); + internal::ShortestPath(ifst, ofst, opts); +} + +void ShortestPath(const FstClass &ifst, MutableFstClass *ofst, + const ShortestPathOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SHORTEST_PATH_H_ diff --git a/projects/llm_framework/include/fst/script/stateiterator-class.h b/projects/llm_framework/include/fst/script/stateiterator-class.h new file mode 100644 index 00000000..f6fddfe6 --- /dev/null +++ b/projects/llm_framework/include/fst/script/stateiterator-class.h @@ -0,0 +1,85 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_STATEITERATOR_CLASS_H_ +#define FST_SCRIPT_STATEITERATOR_CLASS_H_ + +#include + +#include +#include + +// Scripting API support for StateIterator. + +namespace fst { +namespace script { + +// Virtual interface implemented by each concrete StateIteratorImpl. +class StateIteratorImplBase { + public: + virtual bool Done() const = 0; + virtual int64 Value() const = 0; + virtual void Next() = 0; + virtual void Reset() = 0; + virtual ~StateIteratorImplBase() {} +}; + +// Templated implementation. +template +class StateIteratorClassImpl : public StateIteratorImplBase { + public: + explicit StateIteratorClassImpl(const Fst &fst) : siter_(fst) {} + + bool Done() const final { return siter_.Done(); } + + int64 Value() const final { return siter_.Value(); } + + void Next() final { siter_.Next(); } + + void Reset() final { siter_.Reset(); } + + ~StateIteratorClassImpl() override {} + + private: + StateIterator> siter_; +}; + +class StateIteratorClass; + +using InitStateIteratorClassArgs = + std::pair; + +// Untemplated user-facing class holding a templated pimpl. +class StateIteratorClass { + public: + explicit StateIteratorClass(const FstClass &fst); + + template + explicit StateIteratorClass(const Fst &fst) + : impl_(new StateIteratorClassImpl(fst)) {} + + bool Done() const { return impl_->Done(); } + + int64 Value() const { return impl_->Value(); } + + void Next() { impl_->Next(); } + + void Reset() { impl_->Reset(); } + + template + friend void InitStateIteratorClass(InitStateIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitStateIteratorClass(InitStateIteratorClassArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::get<1>(*args)->impl_.reset(new StateIteratorClassImpl(fst)); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_STATEITERATOR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/synchronize.h b/projects/llm_framework/include/fst/script/synchronize.h new file mode 100644 index 00000000..01df151a --- /dev/null +++ b/projects/llm_framework/include/fst/script/synchronize.h @@ -0,0 +1,29 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_SYNCHRONIZE_H_ +#define FST_SCRIPT_SYNCHRONIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using SynchronizeArgs = std::pair; + +template +void Synchronize(SynchronizeArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Synchronize(ifst, ofst); +} + +void Synchronize(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SYNCHRONIZE_H_ diff --git a/projects/llm_framework/include/fst/script/text-io.h b/projects/llm_framework/include/fst/script/text-io.h new file mode 100644 index 00000000..464bf885 --- /dev/null +++ b/projects/llm_framework/include/fst/script/text-io.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utilities for reading and writing textual strings representing states, +// labels, and weights and files specifying label-label pairs and potentials +// (state-weight pairs). + +#ifndef FST_SCRIPT_TEXT_IO_H__ +#define FST_SCRIPT_TEXT_IO_H__ + +#include +#include + +#include + +namespace fst { +namespace script { + +bool ReadPotentials(const string &weight_type, const string &filename, + std::vector *potentials); + +bool WritePotentials(const string &filename, + const std::vector &potentials); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TEXT_IO_H__ diff --git a/projects/llm_framework/include/fst/script/topsort.h b/projects/llm_framework/include/fst/script/topsort.h new file mode 100644 index 00000000..fb6738d7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/topsort.h @@ -0,0 +1,26 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_TOPSORT_H_ +#define FST_SCRIPT_TOPSORT_H_ + +#include +#include +#include + +namespace fst { +namespace script { + +using TopSortArgs = WithReturnValue; + +template +void TopSort(TopSortArgs *args) { + args->retval = TopSort(args->args->GetMutableFst()); +} + +bool TopSort(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TOPSORT_H_ diff --git a/projects/llm_framework/include/fst/script/union.h b/projects/llm_framework/include/fst/script/union.h new file mode 100644 index 00000000..9493e2b1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/union.h @@ -0,0 +1,29 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_UNION_H_ +#define FST_SCRIPT_UNION_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using UnionArgs = std::pair; + +template +void Union(UnionArgs *args) { + MutableFst *fst1 = std::get<0>(*args)->GetMutableFst(); + const Fst &fst2 = *(std::get<1>(*args).GetFst()); + Union(fst1, fst2); +} + +void Union(MutableFstClass *fst1, const FstClass &fst2); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_UNION_H_ diff --git a/projects/llm_framework/include/fst/script/verify.h b/projects/llm_framework/include/fst/script/verify.h new file mode 100644 index 00000000..52f58641 --- /dev/null +++ b/projects/llm_framework/include/fst/script/verify.h @@ -0,0 +1,27 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_VERIFY_H_ +#define FST_SCRIPT_VERIFY_H_ + +#include +#include +#include + +namespace fst { +namespace script { + +using VerifyArgs = WithReturnValue; + +template +void Verify(VerifyArgs *args) { + const Fst &fst = *(args->args.GetFst()); + args->retval = Verify(fst); +} + +bool Verify(const FstClass &fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_VERIFY_H_ diff --git a/projects/llm_framework/include/fst/script/weight-class.h b/projects/llm_framework/include/fst/script/weight-class.h new file mode 100644 index 00000000..6dadf92c --- /dev/null +++ b/projects/llm_framework/include/fst/script/weight-class.h @@ -0,0 +1,235 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Represents a generic weight in an FST; that is, represents a specific type +// of weight underneath while hiding that type from a client. + +#ifndef FST_SCRIPT_WEIGHT_CLASS_H_ +#define FST_SCRIPT_WEIGHT_CLASS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +class WeightImplBase { + public: + virtual WeightImplBase *Copy() const = 0; + virtual void Print(std::ostream *o) const = 0; + virtual const string &Type() const = 0; + virtual string ToString() const = 0; + virtual bool Member() const = 0; + virtual bool operator==(const WeightImplBase &other) const = 0; + virtual bool operator!=(const WeightImplBase &other) const = 0; + virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &PowerEq(size_t n) = 0; + virtual ~WeightImplBase() {} +}; + +template +class WeightClassImpl : public WeightImplBase { + public: + explicit WeightClassImpl(const W &weight) : weight_(weight) {} + + WeightClassImpl *Copy() const final { + return new WeightClassImpl(weight_); + } + + const string &Type() const final { return W::Type(); } + + void Print(std::ostream *ostrm) const final { *ostrm << weight_; } + + string ToString() const final { + string str; + WeightToStr(weight_, &str); + return str; + } + + bool Member() const final { return weight_.Member(); } + + bool operator==(const WeightImplBase &other) const final { + const auto *typed_other = static_cast *>(&other); + return weight_ == typed_other->weight_; + } + + bool operator!=(const WeightImplBase &other) const final { + return !(*this == other); + } + + WeightClassImpl &PlusEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Plus(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &TimesEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Times(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &DivideEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Divide(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &PowerEq(size_t n) final { + weight_ = Power(weight_, n); + return *this; + } + + W *GetImpl() { return &weight_; } + + private: + W weight_; +}; + + +class WeightClass { + public: + WeightClass() = default; + + template + explicit WeightClass(const W &weight) + : impl_(new WeightClassImpl(weight)) {} + + template + explicit WeightClass(const WeightClassImpl &impl) + : impl_(new WeightClassImpl(impl)) {} + + WeightClass(const string &weight_type, const string &weight_str); + + WeightClass(const WeightClass &other) + : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {} + + WeightClass &operator=(const WeightClass &other) { + impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr); + return *this; + } + + static constexpr const char *__ZERO__ = "__ZERO__"; // NOLINT + + static WeightClass Zero(const string &weight_type); + + static constexpr const char *__ONE__ = "__ONE__"; // NOLINT + + static WeightClass One(const string &weight_type); + + static constexpr const char *__NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT + + static WeightClass NoWeight(const string &weight_type); + + template + const W *GetWeight() const { + if (W::Type() != impl_->Type()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + string ToString() const { return (impl_) ? impl_->ToString() : "none"; } + + const string &Type() const { + if (impl_) return impl_->Type(); + static const string *const no_type = new string("none"); + return *no_type; + } + + bool Member() const { return impl_ && impl_->Member(); } + + bool WeightTypesMatch(const WeightClass &other, const string &op_name) const; + + friend bool operator==(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Power(const WeightClass &w, size_t n); + + private: + const WeightImplBase *GetImpl() const { return impl_.get(); } + + WeightImplBase *GetImpl() { return impl_.get(); } + + std::unique_ptr impl_; + + friend std::ostream &operator<<(std::ostream &o, const WeightClass &c); +}; + +bool operator==(const WeightClass &lhs, const WeightClass &rhs); + +bool operator!=(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Times(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Power(const WeightClass &w, size_t n); + +std::ostream &operator<<(std::ostream &o, const WeightClass &c); + +// Registration for generic weight types. + +using StrToWeightImplBaseT = WeightImplBase *(*)(const string &str, + const string &src, + size_t nline); + +template +WeightImplBase *StrToWeightImplBase(const string &str, const string &src, + size_t nline) { + if (str == WeightClass::__ZERO__) + return new WeightClassImpl(W::Zero()); + else if (str == WeightClass::__ONE__) + return new WeightClassImpl(W::One()); + else if (str == WeightClass::__NOWEIGHT__) + return new WeightClassImpl(W::NoWeight()); + return new WeightClassImpl(StrToWeight(str, src, nline)); +} + +class WeightClassRegister : public GenericRegister { + protected: + string ConvertKeyToSoFilename(const string &key) const final { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + ".so"; + } +}; + +using WeightClassRegisterer = GenericRegisterer; + +// Internal version; needs to be called by wrapper in order for macro args to +// expand. +#define REGISTER_FST_WEIGHT__(Weight, line) \ + static WeightClassRegisterer weight_registerer##_##line( \ + Weight::Type(), StrToWeightImplBase) + +// This layer is where __FILE__ and __LINE__ are expanded. +#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \ + REGISTER_FST_WEIGHT__(Weight, line) + +// Macro for registering new weight types; clients call this. +#define REGISTER_FST_WEIGHT(Weight) \ + REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__) + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_WEIGHT_CLASS_H_ diff --git a/projects/llm_framework/include/fst/set-weight.h b/projects/llm_framework/include/fst/set-weight.h new file mode 100644 index 00000000..dd665f0d --- /dev/null +++ b/projects/llm_framework/include/fst/set-weight.h @@ -0,0 +1,618 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Weights consisting of sets (of integral Labels) and +// associated semiring operation definitions using intersect +// and union. + +#ifndef FST_SET_WEIGHT_H_ +#define FST_SET_WEIGHT_H_ + +#include + +#include +#include +#include +#include + +#include +#include + + +namespace fst { + +constexpr int kSetEmpty = 0; // Label for the empty set. +constexpr int kSetUniv = -1; // Label for the universal set. +constexpr int kSetBad = -2; // Label for a non-set. +constexpr char kSetSeparator = '_'; // Label separator in sets. + +// Determines whether to use (intersect, union) or (union, intersect) +// as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICTED is a +// restricted version of (intersect, union) that requires summed +// arguments to be equal (or an error is signalled), useful for +// algorithms that require a unique labelled path weight. SET_BOOLEAN +// treats all non-Zero() elements as equivalent (with Zero() == +// UnivSet()), useful for algorithms that don't really depend on the +// detailed sets. +enum SetType { SET_INTERSECT_UNION = 0, + SET_UNION_INTERSECT = 1, + SET_INTERSECT_UNION_RESTRICT = 2, + SET_BOOLEAN = 3 }; + +template +class SetWeightIterator; + +// Set semiring of integral labels. +template +class SetWeight { + public: + using Label = Label_; + using ReverseWeight = SetWeight; + using Iterator = SetWeightIterator; + friend class SetWeightIterator; + // Allow type-converting copy and move constructors private access. + template + friend class SetWeight; + + SetWeight() {} + + // Input should be positive, sorted and unique. + template + SetWeight(const Iterator &begin, const Iterator &end) { + for (auto iter = begin; iter != end; ++iter) PushBack(*iter); + } + + // Input should be positive. (Non-positive value has + // special internal meaning w.r.t. integral constants above.) + explicit SetWeight(Label label) { PushBack(label); } + + template + explicit SetWeight(const SetWeight &w) + : first_(w.first_), rest_(w.rest_) {} + + template + explicit SetWeight(SetWeight &&w) + : first_(w.first_), rest_(std::move(w.rest_)) { w.Clear(); } + + template + SetWeight &operator=(const SetWeight &w) { + first_ = w.first_; + rest_ = w.rest_; + return *this; + } + + template + SetWeight &operator=(SetWeight &&w) { + first_ = w.first_; + rest_ = std::move(w.rest_); + w.Clear(); + return *this; + } + + static const SetWeight &Zero() { + return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet(); + } + + static const SetWeight &One() { + return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet(); + } + + static const SetWeight &NoWeight() { + static const auto *const no_weight = new SetWeight(Label(kSetBad)); + return *no_weight; + } + + static const string &Type() { + static const string *const type = new string( + S == SET_UNION_INTERSECT + ? "union_intersect_set" + : (S == SET_INTERSECT_UNION + ? "intersect_union_set" + : (S == SET_INTERSECT_UNION_RESTRICT + ? "restricted_set_intersect_union" + : "boolean_set"))); + return *type; + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + SetWeight Quantize(float delta = kDelta) const { return *this; } + + ReverseWeight Reverse() const; + + static constexpr uint64 Properties() { + return kIdempotent | kLeftSemiring | kRightSemiring | kCommutative; + } + + // These operations combined with the SetWeightIterator + // provide the access and mutation of the set internal elements. + + // The empty set. + static const SetWeight &EmptySet() { + static const auto *const empty = new SetWeight(Label(kSetEmpty)); + return *empty; + } + + // The univeral set. + static const SetWeight &UnivSet() { + static const auto *const univ = new SetWeight(Label(kSetUniv)); + return *univ; + } + + // Clear existing SetWeight. + void Clear() { + first_ = kSetEmpty; + rest_.clear(); + } + + size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; } + + Label Back() { + if (rest_.empty()) { + return first_; + } else { + return rest_.back(); + } + } + + // Caller must add in sort order and be unique (or error signalled). + // Input should also be positive. Non-positive value (for the first + // push) has special internal meaning w.r.t. integral constants above. + void PushBack(Label label) { + if (first_ == kSetEmpty) { + first_ = label; + } else { + if (label <= Back() || label <= 0) { + FSTERROR() << "SetWeight: labels must be positive, added" + << " in sort order and be unique."; + rest_.push_back(Label(kSetBad)); + } + rest_.push_back(label); + } + } + + private: + Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty). + std::list &fst); +// +// // Required copy constructor that allows updating FST argument; +// // pass only if relevant and changed. +// StateMapper(const StateMapper &mapper, const Fst *fst = 0); +// +// // Specifies initial state of result. +// B::StateId Start() const; +// // Specifies state's final weight in result. +// B::Weight Final(B::StateId state) const; +// +// // These methods iterate through a state's arcs in result. +// +// // Specifies state to iterate over. +// void SetState(B::StateId state); +// +// // End of arcs? +// bool Done() const; +// +// // Current arc. +// const B &Value() const; +// +// // Advances to next arc (when !Done) +// void Next(); +// +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// +// // This specifies the known properties of an FST mapped by this +// // mapper. It takes as argument the input FST's known properties. +// uint64 Properties(uint64 props) const; +// }; +// +// We include a various state map versions below. One dimension of variation is +// whether the mapping mutates its input, writes to a new result FST, or is an +// on-the-fly Fst. Another dimension is how we pass the mapper. We allow passing +// the mapper by pointer for cases that we need to change the state of the +// user's mapper. We also include map versions that pass the mapper by value or +// const reference when this suffices. + +// Maps an arc type A using a mapper function object C, passed by pointer. This +// version modifies the input FST. +template +void StateMap(MutableFst *fst, C *mapper) { + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetOutputSymbols(nullptr); + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kFstProperties, false); + fst->SetStart(mapper->Start()); + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + mapper->SetState(state); + fst->DeleteArcs(state); + for (; !mapper->Done(); mapper->Next()) { + fst->AddArc(state, mapper->Value()); + } + fst->SetFinal(state, mapper->Final(state)); + } + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + +// Maps an arc type A using a mapper function object C, passed by value. +// This version modifies the input FST. +template +void StateMap(MutableFst *fst, C mapper) { + StateMap(fst, &mapper); +} + +// Maps an arc type A to an arc type B using mapper functor C, passed by +// pointer. This version writes to an output FST. +template +void StateMap(const Fst &ifst, MutableFst *ofst, C *mapper) { + ofst->DeleteStates(); + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetInputSymbols(ifst.InputSymbols()); + } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetOutputSymbols(nullptr); + } + const auto iprops = ifst.Properties(kCopyProperties, false); + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + // Adds all states. + if (ifst.Properties(kExpanded, false)) ofst->ReserveStates(CountStates(ifst)); + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + ofst->AddState(); + } + ofst->SetStart(mapper->Start()); + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + mapper->SetState(state); + for (; !mapper->Done(); mapper->Next()) { + ofst->AddArc(state, mapper->Value()); + } + ofst->SetFinal(state, mapper->Final(state)); + } + const auto oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper functor object C, passed by +// value. This version writes to an output FST. +template +void StateMap(const Fst &ifst, MutableFst *ofst, C mapper) { + StateMap(ifst, ofst, &mapper); +} + +using StateMapFstOptions = CacheOptions; + +template +class StateMapFst; + +// Facade around StateIteratorBase inheriting from StateIteratorBase. +template +class StateMapStateIteratorBase : public StateIteratorBase { + public: + using Arc = B; + using StateId = typename Arc::StateId; + + explicit StateMapStateIteratorBase(StateIteratorBase *base) + : base_(base) {} + + bool Done() const final { return base_->Done(); } + + StateId Value() const final { return base_->Value(); } + + void Next() final { base_->Next(); } + + void Reset() final { base_->Reset(); } + + private: + std::unique_ptr> base_; + + StateMapStateIteratorBase() = delete; +}; + +namespace internal { + +// Implementation of delayed StateMapFst. +template +class StateMapFstImpl : public CacheImpl { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::PushArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + StateMapFstImpl(const Fst &fst, const C &mapper, + const StateMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(new C(mapper, fst_.get())), + own_mapper_(true) { + Init(); + } + + StateMapFstImpl(const Fst &fst, C *mapper, const StateMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false) { + Init(); + } + + StateMapFstImpl(const StateMapFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_, fst_.get())), + own_mapper_(true) { + Init(); + } + + ~StateMapFstImpl() override { + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) SetStart(mapper_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId state) { + if (!HasFinal(state)) SetFinal(state, mapper_->Final(state)); + return CacheImpl::Final(state); + } + + size_t NumArcs(StateId state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumArcs(state); + } + + size_t NumInputEpsilons(StateId state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumInputEpsilons(state); + } + + size_t NumOutputEpsilons(StateId state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumOutputEpsilons(state); + } + + void InitStateIterator(StateIteratorData *datb) const { + StateIteratorData data; + fst_->InitStateIterator(&data); + datb->base = data.base ? new StateMapStateIteratorBase(data.base) + : nullptr; + datb->nstates = data.nstates; + } + + void InitArcIterator(StateId state, ArcIteratorData *data) { + if (!HasArcs(state)) Expand(state); + CacheImpl::InitArcIterator(state, data); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void Expand(StateId state) { + // Adds exiting arcs. + for (mapper_->SetState(state); !mapper_->Done(); mapper_->Next()) { + PushArc(state, mapper_->Value()); + } + SetArcs(state); + } + + const Fst *GetFst() const { return fst_.get(); } + + private: + void Init() { + SetType("statemap"); + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetInputSymbols(fst_->InputSymbols()); + } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetInputSymbols(nullptr); + } + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetOutputSymbols(fst_->OutputSymbols()); + } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetOutputSymbols(nullptr); + } + const auto props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + } + + std::unique_ptr> fst_; + C *mapper_; + bool own_mapper_; +}; + +} // namespace internal + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed FST. +template +class StateMapFst : public ImplToFst> { + public: + friend class ArcIterator>; + + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::StateMapFstImpl; + + StateMapFst(const Fst &fst, const C &mapper, + const StateMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + StateMapFst(const Fst &fst, C *mapper, const StateMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + StateMapFst(const Fst &fst, const C &mapper) + : ImplToFst( + std::make_shared(fst, mapper, StateMapFstOptions())) {} + + StateMapFst(const Fst &fst, C *mapper) + : ImplToFst( + std::make_shared(fst, mapper, StateMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + StateMapFst(const StateMapFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. + StateMapFst *Copy(bool safe = false) const override { + return new StateMapFst(*this, safe); + } + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId state, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(state, data); + } + + protected: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + private: + StateMapFst &operator=(const StateMapFst &) = delete; +}; + +// Specialization for StateMapFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename A::StateId; + + ArcIterator(const StateMapFst &fst, StateId state) + : CacheArcIterator>(fst.GetMutableImpl(), state) { + if (!fst.GetImpl()->HasArcs(state)) fst.GetMutableImpl()->Expand(state); + } +}; + +// Utility mappers. + +// Mapper that returns its input. +template +class IdentityStateMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit IdentityStateMapper(const Fst &fst) : fst_(fst) {} + + // Allows updating FST argument; pass only if changed. + IdentityStateMapper(const IdentityStateMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + aiter_.reset(new ArcIterator>(fst_, state)); + } + + bool Done() const { return aiter_->Done(); } + + const Arc &Value() const { return aiter_->Value(); } + + void Next() { aiter_->Next(); } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { return props; } + + private: + const Fst &fst_; + std::unique_ptr>> aiter_; +}; + +template +class ArcSumMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ArcSumMapper(const Fst &fst) : fst_(fst), i_(0) {} + + // Allows updating FST argument; pass only if changed. + ArcSumMapper(const ArcSumMapper &mapper, const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(state)); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + // First sorts the exiting arcs by input label, output label and destination + // state and then sums weights of arcs with the same input label, output + // label, and destination state. + std::sort(arcs_.begin(), arcs_.end(), comp_); + size_t narcs = 0; + for (const auto &arc : arcs_) { + if (narcs > 0 && equal_(arc, arcs_[narcs - 1])) { + arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, arc.weight); + } else { + arcs_[narcs] = arc; + ++narcs; + } + } + arcs_.resize(narcs); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & kDeleteArcsProperties & + kWeightInvariantProperties; + } + + private: + struct Compare { + bool operator()(const Arc &x, const Arc &y) const { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const Arc &x, const Arc &y) const { + return (x.ilabel == y.ilabel && x.olabel == y.olabel && + x.nextstate == y.nextstate); + } + }; + + const Fst &fst_; + Compare comp_; + Equal equal_; + std::vector arcs_; + ssize_t i_; // Current arc position. + + ArcSumMapper &operator=(const ArcSumMapper &) = delete; +}; + +template +class ArcUniqueMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ArcUniqueMapper(const Fst &fst) : fst_(fst), i_(0) {} + + // Allows updating FST argument; pass only if changed. + ArcUniqueMapper(const ArcUniqueMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(state)); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + // First sorts the exiting arcs by input label, output label and destination + // state and then uniques identical arcs. + std::sort(arcs_.begin(), arcs_.end(), comp_); + arcs_.erase(std::unique(arcs_.begin(), arcs_.end(), equal_), arcs_.end()); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & kDeleteArcsProperties; + } + + private: + struct Compare { + bool operator()(const Arc &x, const Arc &y) const { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const Arc &x, const Arc &y) const { + return (x.ilabel == y.ilabel && x.olabel == y.olabel && + x.nextstate == y.nextstate && x.weight == y.weight); + } + }; + + const Fst &fst_; + Compare comp_; + Equal equal_; + std::vector arcs_; + size_t i_; // Current arc position. + + ArcUniqueMapper &operator=(const ArcUniqueMapper &) = delete; +}; + +// Useful aliases when using StdArc. + +using StdArcSumMapper = ArcSumMapper; + +using StdArcUniqueMapper = ArcUniqueMapper; + +} // namespace fst + +#endif // FST_STATE_MAP_H_ diff --git a/projects/llm_framework/include/fst/state-reachable.h b/projects/llm_framework/include/fst/state-reachable.h new file mode 100644 index 00000000..36b5559a --- /dev/null +++ b/projects/llm_framework/include/fst/state-reachable.h @@ -0,0 +1,224 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to determine whether a given (final) state can be reached from some +// other given state. + +#ifndef FST_STATE_REACHABLE_H_ +#define FST_STATE_REACHABLE_H_ + +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// Computes the (final) states reachable from a given state in an FST. After +// this visitor has been called, a final state f can be reached from a state +// s iff (*isets)[s].Member(state2index[f]) is true, where (*isets[s]) is a +// set of half-open inteval of final state indices and state2index[f] maps from +// a final state to its index. If state2index is empty, it is filled-in with +// suitable indices. If it is non-empty, those indices are used; in this case, +// the final states must have out-degree 0. +template > +class IntervalReachVisitor { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Index = I; + using ISet = S; + using Interval = typename ISet::Interval; + + IntervalReachVisitor(const Fst &fst, std::vector *isets, + std::vector *state2index) + : fst_(fst), + isets_(isets), + state2index_(state2index), + index_(state2index->empty() ? 1 : -1), + error_(false) { + isets_->clear(); + } + + void InitVisit(const Fst &) { error_ = false; } + + bool InitState(StateId s, StateId r) { + while (isets_->size() <= s) isets_->push_back(S()); + while (state2index_->size() <= s) state2index_->push_back(-1); + if (fst_.Final(s) != Weight::Zero()) { + // Create tree interval. + auto *intervals = (*isets_)[s].MutableIntervals(); + if (index_ < 0) { // Uses state2index_ map to set index. + if (fst_.NumArcs(s) > 0) { + FSTERROR() << "IntervalReachVisitor: state2index map must be empty " + << "for this FST"; + error_ = true; + return false; + } + const auto index = (*state2index_)[s]; + if (index < 0) { + FSTERROR() << "IntervalReachVisitor: state2index map incomplete"; + error_ = true; + return false; + } + intervals->push_back(Interval(index, index + 1)); + } else { // Use pre-order index. + intervals->push_back(Interval(index_, index_ + 1)); + (*state2index_)[s] = index_++; + } + } + return true; + } + + constexpr bool TreeArc(StateId, const Arc &) const { return true; } + + bool BackArc(StateId s, const Arc &arc) { + FSTERROR() << "IntervalReachVisitor: Cyclic input"; + error_ = true; + return false; + } + + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + // Non-tree interval. + (*isets_)[s].Union((*isets_)[arc.nextstate]); + return true; + } + + void FinishState(StateId s, StateId p, const Arc *) { + if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) { + auto *intervals = (*isets_)[s].MutableIntervals(); + (*intervals)[0].end = index_; // Updates tree interval end. + } + (*isets_)[s].Normalize(); + if (p != kNoStateId) { + (*isets_)[p].Union((*isets_)[s]); // Propagates intervals to parent. + } + } + + void FinishVisit() {} + + bool Error() const { return error_; } + + private: + const Fst &fst_; + std::vector *isets_; + std::vector *state2index_; + Index index_; + bool error_; +}; + +// Tests reachability of final states from a given state. To test for +// reachability from a state s, first do SetState(s). Then a final state f can +// be reached from state s of FST iff Reach(f) is true. The input can be cyclic, +// but no cycle may contain a final state. +template > +class StateReachable { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Index = I; + using ISet = S; + using Interval = typename ISet::Interval; + + explicit StateReachable(const Fst &fst) : error_(false) { + if (fst.Properties(kAcyclic, true)) { + AcyclicStateReachable(fst); + } else { + CyclicStateReachable(fst); + } + } + + explicit StateReachable(const StateReachable &reachable) { + FSTERROR() << "Copy constructor for state reachable class " + << "not implemented."; + error_ = true; + } + + // Sets current state. + void SetState(StateId s) { s_ = s; } + + // Can reach this final state from current state? + bool Reach(StateId s) { + if (s >= state2index_.size()) return false; + const auto i = state2index_[s]; + if (i < 0) { + FSTERROR() << "StateReachable: State non-final: " << s; + error_ = true; + return false; + } + return isets_[s_].Member(i); + } + + // Access to the state-to-index mapping. Unassigned states have index -1. + std::vector &State2Index() { return state2index_; } + + // Access to the interval sets. These specify the reachability to the final + // states as intervals of the final state indices. + const std::vector &IntervalSets() { return isets_; } + + bool Error() const { return error_; } + + private: + void AcyclicStateReachable(const Fst &fst) { + IntervalReachVisitor reach_visitor(fst, &isets_, + &state2index_); + DfsVisit(fst, &reach_visitor); + if (reach_visitor.Error()) error_ = true; + } + + void CyclicStateReachable(const Fst &fst) { + // Finds state reachability on the acyclic condensation FST. + VectorFst cfst; + std::vector scc; + Condense(fst, &cfst, &scc); + StateReachable reachable(cfst); + if (reachable.Error()) { + error_ = true; + return; + } + // Gets the number of states per SCC. + std::vector nscc; + for (StateId s = 0; s < scc.size(); ++s) { + const auto c = scc[s]; + while (c >= nscc.size()) nscc.push_back(0); + ++nscc[c]; + } + // Constructs the interval sets and state index mapping for the original + // FST from the condensation FST. + state2index_.resize(scc.size(), -1); + isets_.resize(scc.size()); + for (StateId s = 0; s < scc.size(); ++s) { + const auto c = scc[s]; + isets_[s] = reachable.IntervalSets()[c]; + state2index_[s] = reachable.State2Index()[c]; + // Checks that each final state in an input FST is not contained in a + // cycle (i.e., not in a non-trivial SCC). + if (cfst.Final(c) != Weight::Zero() && nscc[c] > 1) { + FSTERROR() << "StateReachable: Final state contained in a cycle"; + error_ = true; + return; + } + } + } + + StateId s_; // Current state. + std::vector isets_; // Interval sets per state. + std::vector state2index_; // Finds index for a final state. + bool error_; + + StateReachable &operator=(const StateReachable &) = delete; +}; + +} // namespace fst + +#endif // FST_STATE_REACHABLE_H_ diff --git a/projects/llm_framework/include/fst/state-table.h b/projects/llm_framework/include/fst/state-table.h new file mode 100644 index 00000000..a5067592 --- /dev/null +++ b/projects/llm_framework/include/fst/state-table.h @@ -0,0 +1,494 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for representing the mapping between state tuples and state IDs. + +#ifndef FST_STATE_TABLE_H_ +#define FST_STATE_TABLE_H_ + +#include +#include +#include + +#include + +#include +#include +#include + + +namespace fst { + +// State tables determine the bijective mapping between state tuples (e.g., in +// composition, triples of two FST states and a composition filter state) and +// their corresponding state IDs. They are classes, templated on state tuples, +// with the following interface: +// +// template +// class StateTable { +// public: +// using StateTuple = T; +// +// // Required constructors. +// StateTable(); +// +// StateTable(const StateTable &); +// +// // Looks up state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &tuple); +// +// // Looks up state tuple by state ID. +// const StateTuple &Tuple(StateId s) const; +// +// // # of stored tuples. +// StateId Size() const; +// }; +// +// A state tuple has the form: +// +// template +// struct StateTuple { +// using StateId = S; +// +// // Required constructors. +// +// StateTuple(); +// +// StateTuple(const StateTuple &tuple); +// }; + +// An implementation using a hash map for the tuple to state ID mapping. The +// state tuple T must support operator==. +template +class HashStateTable : public HashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using HashBiTable::FindId; + using HashBiTable::FindEntry; + using HashBiTable::Size; + + HashStateTable() : HashBiTable() {} + + explicit HashStateTable(size_t table_size) + : HashBiTable(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a hash map for the tuple to state ID mapping. The +// state tuple T must support operator==. +template +class CompactHashStateTable + : public CompactHashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using CompactHashBiTable::FindId; + using CompactHashBiTable::FindEntry; + using CompactHashBiTable::Size; + + CompactHashStateTable() : CompactHashBiTable() {} + + explicit CompactHashStateTable(size_t table_size) + : CompactHashBiTable(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a vector for the tuple to state mapping. It is +// passed a fingerprint functor that should fingerprint tuples uniquely to an +// integer that can used as a vector index. Normally, VectorStateTable +// constructs the fingerprint functor. Alternately, the user can pass this +// object, in which case the table takes ownership. +template +class VectorStateTable : public VectorBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using VectorBiTable::FindId; + using VectorBiTable::FindEntry; + using VectorBiTable::Size; + using VectorBiTable::Fingerprint; + + explicit VectorStateTable(FP *fingerprint = nullptr, size_t table_size = 0) + : VectorBiTable(fingerprint, table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a vector and a compact hash table. The selection +// functor returns true for tuples to be hashed in the vector. The fingerprint +// functor should fingerprint tuples uniquely to an integer that can be used as +// a vector index. A hash functor is used when hashing tuples into the compact +// hash table. +template +class VectorHashStateTable + : public VectorHashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using VectorHashBiTable::FindId; + using VectorHashBiTable::FindEntry; + using VectorHashBiTable::Size; + using VectorHashBiTable::Selector; + using VectorHashBiTable::Fingerprint; + using VectorHashBiTable::Hash; + + VectorHashStateTable(Select *select, FP *fingerprint, H *hash, + size_t vector_size = 0, size_t tuple_size = 0) + : VectorHashBiTable( + select, fingerprint, hash, vector_size, tuple_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a hash map to map from tuples to state IDs. This +// version permits erasing of states. The state tuple's default constructor +// must produce a tuple that will never be seen and the table must suppor +// operator==. +template +class ErasableStateTable : public ErasableBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using ErasableBiTable::FindId; + using ErasableBiTable::FindEntry; + using ErasableBiTable::Size; + using ErasableBiTable::Erase; + + ErasableStateTable() : ErasableBiTable() {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// The composition state table has the form: +// +// template +// class ComposeStateTable { +// public: +// using StateId = typename Arc::StateId; +// +// // Required constructors. +// +// ComposeStateTable(const Fst &fst1, const Fst &fst2); +// ComposeStateTable(const ComposeStateTable &table); +// +// // Looks up a state ID by tuple, adding it if doesn't exist. +// StateId FindState(const StateTuple &tuple); +// +// // Looks up a tuple by state ID. +// const ComposeStateTuple &Tuple(StateId s) const; +// +// // The number of of stored tuples. +// StateId Size() const; +// +// // Return true if error was encountered. +// bool Error() const; +// }; +// +// The following interface is used to represent the composition state. +// +// template +// class CompositionStateTuple { +// public: +// using StateId = typename StateId; +// using FS = FilterState; +// +// // Required constructors. +// StateTuple(); +// StateTuple(StateId s1, StateId s2, const FilterState &fs); +// +// StateId StateId1() const; +// StateId StateId2() const; +// +// FilterState GetFilterState() const; +// +// std::pair StatePair() const; +// +// size_t Hash() const; +// +// friend bool operator==(const StateTuple& x, const StateTuple &y); +// } +// +template +class DefaultComposeStateTuple { + public: + using StateId = S; + using FilterState = FS; + + DefaultComposeStateTuple() + : state_pair_(kNoStateId, kNoStateId), fs_(FilterState::NoState()) {} + + DefaultComposeStateTuple(StateId s1, StateId s2, const FilterState &fs) + : state_pair_(s1, s2), fs_(fs) {} + + StateId StateId1() const { return state_pair_.first; } + + StateId StateId2() const { return state_pair_.second; } + + FilterState GetFilterState() const { return fs_; } + + const std::pair &StatePair() const { return state_pair_; } + + friend bool operator==(const DefaultComposeStateTuple &x, + const DefaultComposeStateTuple &y) { + return (&x == &y) || (x.state_pair_ == y.state_pair_ && x.fs_ == y.fs_); + } + + size_t Hash() const { + return static_cast(StateId1()) + + static_cast(StateId2()) * 7853u + + GetFilterState().Hash() * 7867u; + } + + private: + std::pair state_pair_; + FilterState fs_; // State of composition filter. +}; + +// Specialization for TrivialFilterState that does not explicitely store the +// filter state since it is always the unique non-blocking state. +template +class DefaultComposeStateTuple { + public: + using StateId = S; + using FilterState = TrivialFilterState; + + DefaultComposeStateTuple() + : state_pair_(kNoStateId, kNoStateId) {} + + DefaultComposeStateTuple(StateId s1, StateId s2, const FilterState &) + : state_pair_(s1, s2) {} + + StateId StateId1() const { return state_pair_.first; } + + StateId StateId2() const { return state_pair_.second; } + + FilterState GetFilterState() const { return FilterState(true); } + + const std::pair &StatePair() const { return state_pair_; } + + friend bool operator==(const DefaultComposeStateTuple &x, + const DefaultComposeStateTuple &y) { + return (&x == &y) || (x.state_pair_ == y.state_pair_); + } + + size_t Hash() const { return StateId1() + StateId2() * 7853; } + + private: + std::pair state_pair_; +}; + +// Hashing of composition state tuples. +template +class ComposeHash { + public: + size_t operator()(const T &t) const { return t.Hash(); } +}; + +// A HashStateTable over composition tuples. +template , + typename StateTable = + CompactHashStateTable>> +class GenericComposeStateTable : public StateTable { + public: + using StateId = typename Arc::StateId; + + GenericComposeStateTable(const Fst &fst1, const Fst &fst2) {} + + GenericComposeStateTable(const Fst &fst1, const Fst &fst2, + size_t table_size) + : StateTable(table_size) {} + + constexpr bool Error() const { return false; } + + private: + GenericComposeStateTable &operator=(const GenericComposeStateTable &table) = + delete; +}; + +// Fingerprint for general composition tuples. +template +class ComposeFingerprint { + public: + using StateId = typename StateTuple::StateId; + + // Required but suboptimal constructor. + ComposeFingerprint() : mult1_(8192), mult2_(8192) { + LOG(WARNING) << "TupleFingerprint: # of FST states should be provided."; + } + + // Constructor is provided the sizes of the input FSTs. + ComposeFingerprint(StateId nstates1, StateId nstates2) + : mult1_(nstates1), mult2_(nstates1 * nstates2) {} + + size_t operator()(const StateTuple &tuple) { + return tuple.StateId1() + tuple.StateId2() * mult1_ + + tuple.GetFilterState().Hash() * mult2_; + } + + private: + const ssize_t mult1_; + const ssize_t mult2_; +}; + +// Useful when the first composition state determines the tuple. +template +class ComposeState1Fingerprint { + public: + size_t operator()(const StateTuple &tuple) { return tuple.StateId1(); } +}; + +// Useful when the second composition state determines the tuple. +template +class ComposeState2Fingerprint { + public: + size_t operator()(const StateTuple &tuple) { return tuple.StateId2(); } +}; + +// A VectorStateTable over composition tuples. This can be used when the +// product of number of states in FST1 and FST2 (and the composition filter +// state hash) is manageable. If the FSTs are not expanded FSTs, they will +// first have their states counted. +template +class ProductComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + ProductComposeStateTable(const Fst &fst1, const Fst &fst2, + size_t table_size = 0) + : StateTable(new ComposeFingerprint(CountStates(fst1), + CountStates(fst2)), + table_size) {} + + ProductComposeStateTable( + const ProductComposeStateTable &table) + : StateTable(new ComposeFingerprint(table.Fingerprint())) {} + + constexpr bool Error() const { return false; } + + private: + ProductComposeStateTable &operator=(const ProductComposeStateTable &table) = + delete; +}; + +// A vector-backed table over composition tuples which can be used when the +// first FST is a string (i.e., satisfies kString property) and the second is +// deterministic and epsilon-free. It should be used with a composition filter +// that creates at most one filter state per tuple under these conditions (e.g., +// SequenceComposeFilter or MatchComposeFilter). +template +class StringDetComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + StringDetComposeStateTable(const Fst &fst1, const Fst &fst2) + : error_(false) { + static constexpr auto props2 = kIDeterministic | kNoIEpsilons; + if (fst1.Properties(kString, true) != kString) { + FSTERROR() << "StringDetComposeStateTable: 1st FST is not a string"; + error_ = true; + } else if (fst2.Properties(props2, true) != props2) { + FSTERROR() << "StringDetComposeStateTable: 2nd FST is not deterministic " + "and epsilon-free"; + error_ = true; + } + } + + StringDetComposeStateTable( + const StringDetComposeStateTable &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + StringDetComposeStateTable &operator=(const StringDetComposeStateTable &) = + delete; +}; + +// A vector-backed table over composition tuples which can be used when the +// first FST is deterministic and epsilon-free and the second is a string (i.e., +// satisfies kString). It should be used with a composition filter that creates +// at most one filter state per tuple under these conditions (e.g., +// SequenceComposeFilter or MatchComposeFilter). +template +class DetStringComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + DetStringComposeStateTable(const Fst &fst1, const Fst &fst2) + : error_(false) { + static constexpr auto props = kODeterministic | kNoOEpsilons; + if (fst1.Properties(props, true) != props) { + FSTERROR() << "StringDetComposeStateTable: 1st FST is not " + << "input-deterministic and epsilon-free"; + error_ = true; + } else if (fst2.Properties(kString, true) != kString) { + FSTERROR() << "DetStringComposeStateTable: 2nd FST is not a string"; + error_ = true; + } + } + + DetStringComposeStateTable( + const DetStringComposeStateTable &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + DetStringComposeStateTable &operator=(const DetStringComposeStateTable &) = + delete; +}; + +// An erasable table over composition tuples. The Erase(StateId) method can be +// called if the user either is sure that composition will never return to that +// tuple or doesn't care that if it does, it is assigned a new state ID. +template +class ErasableComposeStateTable + : public ErasableStateTable> { + public: + ErasableComposeStateTable(const Fst &fst1, const Fst &fst2) {} + + constexpr bool Error() const { return false; } + + private: + ErasableComposeStateTable &operator=(const ErasableComposeStateTable &table) = + delete; +}; + +} // namespace fst + +#endif // FST_STATE_TABLE_H_ diff --git a/projects/llm_framework/include/fst/statesort.h b/projects/llm_framework/include/fst/statesort.h new file mode 100644 index 00000000..346c7d32 --- /dev/null +++ b/projects/llm_framework/include/fst/statesort.h @@ -0,0 +1,74 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to sort states of an FST. + +#ifndef FST_STATESORT_H_ +#define FST_STATESORT_H_ + +#include +#include + +#include + +#include + + +namespace fst { + +// Sorts the input states of an FST. order[i] gives the the state ID after +// sorting that corresponds to the state ID i before sorting; it must +// therefore be a permutation of the input FST's states ID sequence. +template +void StateSort(MutableFst *fst, + const std::vector &order) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (order.size() != fst->NumStates()) { + FSTERROR() << "StateSort: Bad order vector size: " << order.size(); + fst->SetProperties(kError, kError); + return; + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kStateSortProperties, false); + std::vector done(order.size(), false); + std::vector arcsa; + std::vector arcsb; + fst->SetStart(order[fst->Start()]); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + auto s1 = siter.Value(); + StateId s2; + if (done[s1]) continue; + auto final1 = fst->Final(s1); + auto final2 = Weight::Zero(); + arcsa.clear(); + for (ArcIterator> aiter(*fst, s1); !aiter.Done(); + aiter.Next()) { + arcsa.push_back(aiter.Value()); + } + for (; !done[s1]; s1 = s2, final1 = final2, std::swap(arcsa, arcsb)) { + s2 = order[s1]; + if (!done[s2]) { + final2 = fst->Final(s2); + arcsb.clear(); + for (ArcIterator> aiter(*fst, s2); !aiter.Done(); + aiter.Next()) { + arcsb.push_back(aiter.Value()); + } + } + fst->SetFinal(s2, final1); + fst->DeleteArcs(s2); + for (auto arc : arcsa) { // Copy intended. + arc.nextstate = order[arc.nextstate]; + fst->AddArc(s2, arc); + } + done[s1] = true; + } + } + fst->SetProperties(props, kFstProperties); +} + +} // namespace fst + +#endif // FST_STATESORT_H_ diff --git a/projects/llm_framework/include/fst/string-weight.h b/projects/llm_framework/include/fst/string-weight.h new file mode 100644 index 00000000..fb3e70f1 --- /dev/null +++ b/projects/llm_framework/include/fst/string-weight.h @@ -0,0 +1,807 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// String weight set and associated semiring operation definitions. + +#ifndef FST_STRING_WEIGHT_H_ +#define FST_STRING_WEIGHT_H_ + +#include + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +constexpr int kStringInfinity = -1; // Label for the infinite string. +constexpr int kStringBad = -2; // Label for a non-string. +constexpr char kStringSeparator = '_'; // Label separator in strings. + +// Determines whether to use left or right string semiring. Includes a +// 'restricted' version that signals an error if proper prefixes/suffixes +// would otherwise be returned by Plus, useful with various +// algorithms that require functional transducer input with the +// string semirings. +enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1, STRING_RESTRICT = 2 }; + +constexpr StringType ReverseStringType(StringType s) { + return s == STRING_LEFT ? STRING_RIGHT + : (s == STRING_RIGHT ? STRING_LEFT : STRING_RESTRICT); +} + +template +class StringWeightIterator; +template +class StringWeightReverseIterator; + +// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon) +template +class StringWeight { + public: + using Label = Label_; + using ReverseWeight = StringWeight; + using Iterator = StringWeightIterator; + using ReverseIterator = StringWeightReverseIterator; + + friend class StringWeightIterator; + friend class StringWeightReverseIterator; + + StringWeight() {} + + template + StringWeight(const Iterator &begin, const Iterator &end) { + for (auto iter = begin; iter != end; ++iter) PushBack(*iter); + } + + explicit StringWeight(Label label) { PushBack(label); } + + static const StringWeight &Zero() { + static const auto *const zero = new StringWeight(Label(kStringInfinity)); + return *zero; + } + + static const StringWeight &One() { + static const auto *const one = new StringWeight(); + return *one; + } + + static const StringWeight &NoWeight() { + static const auto *const no_weight = new StringWeight(Label(kStringBad)); + return *no_weight; + } + + static const string &Type() { + static const string *const type = new string( + S == STRING_LEFT + ? "left_string" + : (S == STRING_RIGHT ? "right_string" : "restricted_string")); + return *type; + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + StringWeight Quantize(float delta = kDelta) const { return *this; } + + ReverseWeight Reverse() const; + + static constexpr uint64 Properties() { + return kIdempotent | + (S == STRING_LEFT ? kLeftSemiring + : (S == STRING_RIGHT + ? kRightSemiring + : /* S == STRING_RESTRICT */ kLeftSemiring | + kRightSemiring)); + } + + // These operations combined with the StringWeightIterator and + // StringWeightReverseIterator provide the access and mutation of the string + // internal elements. + + // Clear existing StringWeight. + void Clear() { + first_ = 0; + rest_.clear(); + } + + size_t Size() const { return first_ ? rest_.size() + 1 : 0; } + + void PushFront(Label label) { + if (first_) rest_.push_front(first_); + first_ = label; + } + + void PushBack(Label label) { + if (!first_) { + first_ = label; + } else { + rest_.push_back(label); + } + } + + private: + Label first_ = 0; // First label in string (0 if empty). + std::list; + + friend class ArcIterator>; + friend class StateIterator>; + + explicit SynchronizeFst( + const Fst &fst, + const SynchronizeFstOptions &opts = SynchronizeFstOptions()) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + SynchronizeFst(const SynchronizeFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this SynchronizeFst. See Fst<>::Copy() for further doc. + SynchronizeFst *Copy(bool safe = false) const override { + return new SynchronizeFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + SynchronizeFst &operator=(const SynchronizeFst &) = delete; +}; + +// Specialization for SynchronizeFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const SynchronizeFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for SynchronizeFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const SynchronizeFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void SynchronizeFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Synchronizes a transducer. This version writes the synchronized result to a +// MutableFst. The result will be an equivalent FST that has the property that +// during the traversal of a path, the delay is either zero or strictly +// increasing, where the delay is the difference between the number of +// non-epsilon output labels and input labels along the path. +// +// For the algorithm to terminate, the input transducer must have bounded +// delay, i.e., the delay of every cycle must be zero. +// +// Complexity: +// +// - A has bounded delay: exponential. +// - A does not have bounded delay: does not terminate. +// +// For more information, see: +// +// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and +// algorithms. International Journal of Computer Science 14(6): 957-982. +template +void Synchronize(const Fst &ifst, MutableFst *ofst) { + // Caches only the last state for fastest copy. + const SynchronizeFstOptions opts(FLAGS_fst_default_cache_gc, 0); + *ofst = SynchronizeFst(ifst, opts); +} + +} // namespace fst + +#endif // FST_SYNCHRONIZE_H_ diff --git a/projects/llm_framework/include/fst/test-properties.h b/projects/llm_framework/include/fst/test-properties.h new file mode 100644 index 00000000..677ed01b --- /dev/null +++ b/projects/llm_framework/include/fst/test-properties.h @@ -0,0 +1,246 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions to manipulate and test property bits. + +#ifndef FST_TEST_PROPERTIES_H_ +#define FST_TEST_PROPERTIES_H_ + +#include + +#include +#include + +#include +#include + + +DECLARE_bool(fst_verify_properties); + +namespace fst { +// namespace internal { + +// For a binary property, the bit is always returned set. For a trinary (i.e., +// two-bit) property, both bits are returned set iff either corresponding input +// bit is set. +inline uint64 KnownProperties(uint64 props) { + return kBinaryProperties | (props & kTrinaryProperties) | + ((props & kPosTrinaryProperties) << 1) | + ((props & kNegTrinaryProperties) >> 1); +} + +// Tests compatibility between two sets of properties. +inline bool CompatProperties(uint64 props1, uint64 props2) { + const auto known_props1 = KnownProperties(props1); + const auto known_props2 = KnownProperties(props2); + const auto known_props = known_props1 & known_props2; + const auto incompat_props = (props1 & known_props) ^ (props2 & known_props); + if (incompat_props) { + uint64 prop = 1; + for (int i = 0; i < 64; ++i, prop <<= 1) { + if (prop & incompat_props) { + LOG(ERROR) << "CompatProperties: Mismatch: " << PropertyNames[i] + << ": props1 = " << (props1 & prop ? "true" : "false") + << ", props2 = " << (props2 & prop ? "true" : "false"); + } + } + return false; + } else { + return true; + } +} + +// Computes FST property values defined in properties.h. The value of each +// property indicated in the mask will be determined and returned (these will +// never be unknown here). In the course of determining the properties +// specifically requested in the mask, certain other properties may be +// determined (those with little additional expense) and their values will be +// returned as well. The complete set of known properties (whether true or +// false) determined by this operation will be assigned to the the value pointed +// to by KNOWN. If 'use_stored' is true, pre-computed FST properties may be used +// when possible. 'mask & required_mask' is used to determine whether the stored +// propertoes can be used. This routine is seldom called directly; instead it is +// used to implement fst.Properties(mask, true). +template +uint64 ComputeProperties(const Fst &fst, uint64 mask, uint64 *known, + bool use_stored) { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + const auto fst_props = fst.Properties(kFstProperties, false); // FST-stored. + // Check stored FST properties first if allowed. + if (use_stored) { + const auto known_props = KnownProperties(fst_props); + // If FST contains required info, return it. + if ((known_props & mask) == mask) { + if (known) *known = known_props; + return fst_props; + } + } + // Computes (trinary) properties explicitly. + // Initialize with binary properties (already known). + uint64 comp_props = fst_props & kBinaryProperties; + // Computes these trinary properties with a DFS. We compute only those that + // need a DFS here, since we otherwise would like to avoid a DFS since its + // stack could grow large. + uint64 dfs_props = kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible; + std::vector scc; + if (mask & (dfs_props | kWeightedCycles | kUnweightedCycles)) { + SccVisitor scc_visitor(&scc, nullptr, nullptr, &comp_props); + DfsVisit(fst, &scc_visitor); + } + // Computes any remaining trinary properties via a state and arcs iterations + if (mask & ~(kBinaryProperties | dfs_props)) { + comp_props |= kAcceptor | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kTopSorted | + kString; + if (mask & (kIDeterministic | kNonIDeterministic)) { + comp_props |= kIDeterministic; + } + if (mask & (kODeterministic | kNonODeterministic)) { + comp_props |= kODeterministic; + } + if (mask & (dfs_props | kWeightedCycles | kUnweightedCycles)) { + comp_props |= kUnweightedCycles; + } + std::unique_ptr> ilabels; + std::unique_ptr> olabels; + StateId nfinal = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + Arc prev_arc; + // Creates these only if we need to. + if (mask & (kIDeterministic | kNonIDeterministic)) { + ilabels.reset(new std::unordered_set &fst1, const Fst &fst2) { + VLOG(1) << "Check FSTs for sanity (including property bits)."; + CHECK(Verify(fst1)); + CHECK(Verify(fst2)); + + // Ensures seed used once per instantiation. + static UniformArcSelector uniform_selector(seed_); + RandGenOptions> opts(uniform_selector, + kRandomPathLength); + return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts); + } + + // Tests FSA is unambiguous + bool Unambiguous(const Fst &fst) { + VectorFst sfst, dfst; + VectorFst lfst1, lfst2; + Map(fst, &sfst, RmWeightMapper()); + Determinize(sfst, &dfst); + Map(fst, &lfst1, RmWeightMapper()); + Map(dfst, &lfst2, RmWeightMapper()); + return Equiv(lfst1, lfst2); + } + + // Ensures input-epsilon free transducers fst1 and fst2 have the + // same domain and that for each string pair '(is, os)' in fst1, + // '(is, os)' is the minimum weight match to 'is' in fst2. + template + bool MinRelated(const Fst &fst1, const Fst &fst2) { + // Same domain + VectorFst P1(fst1), P2(fst2); + Project(&P1, PROJECT_INPUT); + Project(&P2, PROJECT_INPUT); + if (!Equiv(P1, P2)) { + LOG(ERROR) << "Inputs not equivalent"; + return false; + } + + // Ensures seed used once per instantiation. + static UniformArcSelector uniform_selector(seed_); + RandGenOptions> opts(uniform_selector, + kRandomPathLength); + + VectorFst path, paths1, paths2; + for (ssize_t n = 0; n < kNumRandomPaths; ++n) { + RandGen(fst1, &path, opts); + Invert(&path); + Map(&path, RmWeightMapper()); + Compose(path, fst2, &paths1); + Weight sum1 = ShortestDistance(paths1); + Compose(paths1, path, &paths2); + Weight sum2 = ShortestDistance(paths2); + if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) { + LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2; + return false; + } + } + return true; + } + + // Tests ShortestDistance(A - P) >= + // ShortestDistance(A) times Threshold. + template + bool PruneEquiv(const Fst &fst, const Fst &pfst, Weight threshold) { + VLOG(1) << "Check FSTs for sanity (including property bits)."; + CHECK(Verify(fst)); + CHECK(Verify(pfst)); + + DifferenceFst D(fst, DeterminizeFst(RmEpsilonFst( + ArcMapFst>( + pfst, RmWeightMapper())))); + Weight sum1 = Times(ShortestDistance(fst), threshold); + Weight sum2 = ShortestDistance(D); + return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta); + } + + // Random seed. + int seed_; + // FST with no states + VectorFst zero_fst_; + // FST with one state that accepts epsilon. + VectorFst one_fst_; + // FST with one state that accepts all strings. + VectorFst univ_fst_; + // Generates weights used in testing. + WeightGenerator *weight_generator_; + // Maximum random path length. + static const int kRandomPathLength; + // Number of random paths to explore. + static const int kNumRandomPaths; + // Maximum number of nshortest paths. + static const int kNumRandomShortestPaths; + // Maximum number of nshortest states. + static const int kNumShortestStates; + // Delta for equivalence tests. + static const float kTestDelta; + + WeightedTester(const WeightedTester &) = delete; + WeightedTester &operator=(const WeightedTester &) = delete; +}; + +template +const int WeightedTester::kRandomPathLength = 25; + +template +const int WeightedTester::kNumRandomPaths = 100; + +template +const int WeightedTester::kNumRandomShortestPaths = 100; + +template +const int WeightedTester::kNumShortestStates = 10000; + +template +const float WeightedTester::kTestDelta = .05; + +// This class tests a variety of identities and properties that must +// hold for various algorithms on unweighted FSAs and that are not tested +// by WeightedTester. Only the specialization does anything interesting. +template +class UnweightedTester { + public: + UnweightedTester(const Fst &zero_fsa, const Fst &one_fsa, + const Fst &univ_fsa) {} + + void Test(const Fst &A1, const Fst &A2, const Fst &A3) {} +}; + +// Specialization for StdArc. This should work for any commutative, +// idempotent semiring when restricted to the unweighted case +// (being isomorphic to the boolean semiring). +template <> +class UnweightedTester { + public: + typedef StdArc Arc; + typedef Arc::Label Label; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + + UnweightedTester(const Fst &zero_fsa, const Fst &one_fsa, + const Fst &univ_fsa) + : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {} + + void Test(const Fst &A1, const Fst &A2, const Fst &A3) { + TestRational(A1, A2, A3); + TestIntersect(A1, A2, A3); + TestOptimize(A1); + } + + private: + // Tests rational operations with identities + void TestRational(const Fst &A1, const Fst &A2, + const Fst &A3) { + { + VLOG(1) << "Check the union contains its arguments (destructive)."; + VectorFst U(A1); + Union(&U, A2); + + CHECK(Subset(A1, U)); + CHECK(Subset(A2, U)); + } + + { + VLOG(1) << "Check the union contains its arguments (delayed)."; + UnionFst U(A1, A2); + + CHECK(Subset(A1, U)); + CHECK(Subset(A2, U)); + } + + { + VLOG(1) << "Check if A^n c A* (destructive)."; + VectorFst C(one_fsa_); + int n = rand() % 5; + for (int i = 0; i < n; ++i) Concat(&C, A1); + + VectorFst S(A1); + Closure(&S, CLOSURE_STAR); + CHECK(Subset(C, S)); + } + + { + VLOG(1) << "Check if A^n c A* (delayed)."; + int n = rand() % 5; + Fst *C = new VectorFst(one_fsa_); + for (int i = 0; i < n; ++i) { + ConcatFst *F = new ConcatFst(*C, A1); + delete C; + C = F; + } + ClosureFst S(A1, CLOSURE_STAR); + CHECK(Subset(*C, S)); + delete C; + } + } + + // Tests intersect-based operations. + void TestIntersect(const Fst &A1, const Fst &A2, + const Fst &A3) { + VectorFst S1(A1); + VectorFst S2(A2); + VectorFst S3(A3); + + ILabelCompare comp; + + ArcSort(&S1, comp); + ArcSort(&S2, comp); + ArcSort(&S3, comp); + + { + VLOG(1) << "Check the intersection is contained in its arguments."; + IntersectFst I1(S1, S2); + CHECK(Subset(I1, S1)); + CHECK(Subset(I1, S2)); + } + + { + VLOG(1) << "Check union distributes over intersection."; + IntersectFst I1(S1, S2); + UnionFst U1(I1, S3); + + UnionFst U2(S1, S3); + UnionFst U3(S2, S3); + ArcSortFst> S4(U3, comp); + IntersectFst I2(U2, S4); + + CHECK(Equiv(U1, I2)); + } + + VectorFst C1; + VectorFst C2; + Complement(S1, &C1); + Complement(S2, &C2); + ArcSort(&C1, comp); + ArcSort(&C2, comp); + + { + VLOG(1) << "Check S U S' = Sigma*"; + UnionFst U(S1, C1); + CHECK(Equiv(U, univ_fsa_)); + } + + { + VLOG(1) << "Check S n S' = {}"; + IntersectFst I(S1, C1); + CHECK(Equiv(I, zero_fsa_)); + } + + { + VLOG(1) << "Check (S1' U S2') == (S1 n S2)'"; + UnionFst U(C1, C2); + + IntersectFst I(S1, S2); + VectorFst C3; + Complement(I, &C3); + CHECK(Equiv(U, C3)); + } + + { + VLOG(1) << "Check (S1' n S2') == (S1 U S2)'"; + IntersectFst I(C1, C2); + + UnionFst U(S1, S2); + VectorFst C3; + Complement(U, &C3); + CHECK(Equiv(I, C3)); + } + } + + // Tests optimization operations + void TestOptimize(const Fst &A) { + { + VLOG(1) << "Check determinized FSA is equivalent to its input."; + DeterminizeFst D(A); + CHECK(Equiv(A, D)); + } + + { + VLOG(1) << "Check disambiguated FSA is equivalent to its input."; + VectorFst R(A), D; + RmEpsilon(&R); + + Disambiguate(R, &D); + CHECK(Equiv(R, D)); + } + + { + VLOG(1) << "Check minimized FSA is equivalent to its input."; + int n; + { + RmEpsilonFst R(A); + DeterminizeFst D(R); + VectorFst M(D); + Minimize(&M, static_cast *>(nullptr), kDelta); + CHECK(Equiv(A, M)); + n = M.NumStates(); + } + + if (n) { // Skip test if A is the empty machine + VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the" + << " same number of states as Brozozowski's algorithm"; + VectorFst R; + Reverse(A, &R); + RmEpsilon(&R); + DeterminizeFst DR(R); + VectorFst RD; + Reverse(DR, &RD); + DeterminizeFst DRD(RD); + VectorFst M(DRD); + CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition + // to the initial state + } + } + } + + // Tests if two FSAS are equivalent. + bool Equiv(const Fst &fsa1, const Fst &fsa2) { + VLOG(1) << "Check FSAs for sanity (including property bits)."; + CHECK(Verify(fsa1)); + CHECK(Verify(fsa2)); + + VectorFst vfsa1(fsa1); + VectorFst vfsa2(fsa2); + RmEpsilon(&vfsa1); + RmEpsilon(&vfsa2); + DeterminizeFst dfa1(vfsa1); + DeterminizeFst dfa2(vfsa2); + + // Test equivalence using union-find algorithm + bool equiv1 = Equivalent(dfa1, dfa2); + + // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty + ILabelCompare comp; + VectorFst sdfa1(dfa1); + ArcSort(&sdfa1, comp); + VectorFst sdfa2(dfa2); + ArcSort(&sdfa2, comp); + + DifferenceFst dfsa1(sdfa1, sdfa2); + DifferenceFst dfsa2(sdfa2, sdfa1); + + VectorFst ufsa(dfsa1); + Union(&ufsa, dfsa2); + Connect(&ufsa); + bool equiv2 = ufsa.NumStates() == 0; + + // Check two equivalence tests match + CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2)); + + return equiv1; + } + + // Tests if FSA1 is a subset of FSA2 (disregarding weights). + bool Subset(const Fst &fsa1, const Fst &fsa2) { + VLOG(1) << "Check FSAs (incl. property bits) for sanity"; + CHECK(Verify(fsa1)); + CHECK(Verify(fsa2)); + + VectorFst vfsa1; + VectorFst vfsa2; + RmEpsilon(&vfsa1); + RmEpsilon(&vfsa2); + ILabelCompare comp; + ArcSort(&vfsa1, comp); + ArcSort(&vfsa2, comp); + IntersectFst ifsa(vfsa1, vfsa2); + DeterminizeFst dfa1(vfsa1); + DeterminizeFst dfa2(ifsa); + return Equivalent(dfa1, dfa2); + } + + // Returns complement Fsa + void Complement(const Fst &ifsa, MutableFst *ofsa) { + RmEpsilonFst rfsa(ifsa); + DeterminizeFst dfa(rfsa); + DifferenceFst cfsa(univ_fsa_, dfa); + *ofsa = cfsa; + } + + // FSA with no states + VectorFst zero_fsa_; + + // FSA with one state that accepts epsilon. + VectorFst one_fsa_; + + // FSA with one state that accepts all strings. + VectorFst univ_fsa_; +}; + +// This class tests a variety of identities and properties that must +// hold for various FST algorithms. It randomly generates FSTs, using +// function object 'weight_generator' to select weights. 'WeightTester' +// and 'UnweightedTester' are then called. +template +class AlgoTester { + public: + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + AlgoTester(WeightGenerator generator, int seed) + : weight_generator_(generator) { + one_fst_.AddState(); + one_fst_.SetStart(0); + one_fst_.SetFinal(0, Weight::One()); + + univ_fst_.AddState(); + univ_fst_.SetStart(0); + univ_fst_.SetFinal(0, Weight::One()); + for (int i = 0; i < kNumRandomLabels; ++i) + univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0)); + + weighted_tester_ = new WeightedTester( + seed, zero_fst_, one_fst_, univ_fst_, &weight_generator_); + + unweighted_tester_ = + new UnweightedTester(zero_fst_, one_fst_, univ_fst_); + } + + ~AlgoTester() { + delete weighted_tester_; + delete unweighted_tester_; + } + + void MakeRandFst(MutableFst *fst) { + RandFst(kNumRandomStates, kNumRandomArcs, + kNumRandomLabels, kAcyclicProb, + &weight_generator_, fst); + } + + void Test() { + VLOG(1) << "weight type = " << Weight::Type(); + + for (int i = 0; i < FLAGS_repeat; ++i) { + // Random transducers + VectorFst T1; + VectorFst T2; + VectorFst T3; + MakeRandFst(&T1); + MakeRandFst(&T2); + MakeRandFst(&T3); + weighted_tester_->Test(T1, T2, T3); + + VectorFst A1(T1); + VectorFst A2(T2); + VectorFst A3(T3); + Project(&A1, PROJECT_OUTPUT); + Project(&A2, PROJECT_INPUT); + Project(&A3, PROJECT_INPUT); + ArcMap(&A1, rm_weight_mapper_); + ArcMap(&A2, rm_weight_mapper_); + ArcMap(&A3, rm_weight_mapper_); + unweighted_tester_->Test(A1, A2, A3); + } + } + + private: + // Generates weights used in testing. + WeightGenerator weight_generator_; + + // FST with no states + VectorFst zero_fst_; + + // FST with one state that accepts epsilon. + VectorFst one_fst_; + + // FST with one state that accepts all strings. + VectorFst univ_fst_; + + // Tests weighted FSTs + WeightedTester *weighted_tester_; + + // Tests unweighted FSTs + UnweightedTester *unweighted_tester_; + + // Mapper to remove weights from an Fst + RmWeightMapper rm_weight_mapper_; + + // Maximum number of states in random test Fst. + static const int kNumRandomStates; + + // Maximum number of arcs in random test Fst. + static const int kNumRandomArcs; + + // Number of alternative random labels. + static const int kNumRandomLabels; + + // Probability to force an acyclic Fst + static const float kAcyclicProb; + + // Maximum random path length. + static const int kRandomPathLength; + + // Number of random paths to explore. + static const int kNumRandomPaths; + + AlgoTester(const AlgoTester &) = delete; + AlgoTester &operator=(const AlgoTester &) = delete; +}; + +template +const int AlgoTester::kNumRandomStates = 10; + +template +const int AlgoTester::kNumRandomArcs = 25; + +template +const int AlgoTester::kNumRandomLabels = 5; + +template +const float AlgoTester::kAcyclicProb = .25; + +template +const int AlgoTester::kRandomPathLength = 25; + +template +const int AlgoTester::kNumRandomPaths = 100; + +} // namespace fst + +#endif // FST_TEST_ALGO_TEST_H_ diff --git a/projects/llm_framework/include/fst/test/fst_test.h b/projects/llm_framework/include/fst/test/fst_test.h new file mode 100644 index 00000000..7d536d90 --- /dev/null +++ b/projects/llm_framework/include/fst/test/fst_test.h @@ -0,0 +1,318 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Regression test for FST classes. + +#ifndef FST_TEST_FST_TEST_H_ +#define FST_TEST_FST_TEST_H_ + +#include +#include +#include +#include +#include + +DECLARE_string(tmpdir); + +namespace fst { + +// This tests an Fst F that is assumed to have a copy method from an +// arbitrary Fst. Some test functions make further assumptions mostly +// obvious from their name. These tests are written as member temple +// functions that take a test fst as its argument so that different +// Fsts in the interface hierarchy can be tested separately and so +// that we can instantiate only those tests that make sense for a +// particular Fst. +template +class FstTester { + public: + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + FstTester() { + VectorFst vfst; + InitFst(&vfst, 128); + testfst_ = new F(vfst); + } + + explicit FstTester(F *testfst) : testfst_(testfst) {} + + ~FstTester() { delete testfst_; } + + // This verifies the contents described in InitFst() using + // methods defined in a generic Fst. + template + void TestBase(const G &fst) const { + CHECK(Verify(fst)); + CHECK_EQ(fst.Start(), 0); + StateId ns = 0; + StateIterator siter(fst); + Matcher matcher(fst, MATCH_INPUT); + MatchType match_type = matcher.Type(true); + for (; !siter.Done(); siter.Next()) { + } + for (siter.Reset(); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + matcher.SetState(s); + CHECK_EQ(fst.Final(s), NthWeight(s)); + size_t na = 0; + ArcIterator aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + } + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + ++na; + const Arc &arc = aiter.Value(); + CHECK_EQ(arc.ilabel, na); + CHECK_EQ(arc.olabel, 0); + CHECK_EQ(arc.weight, NthWeight(na)); + CHECK_EQ(arc.nextstate, s); + if (match_type == MATCH_INPUT) { + CHECK(matcher.Find(arc.ilabel)); + CHECK_EQ(matcher.Value().ilabel, arc.ilabel); + } + } + CHECK_EQ(na, s); + CHECK_EQ(na, aiter.Position()); + CHECK_EQ(fst.NumArcs(s), s); + CHECK_EQ(fst.NumInputEpsilons(s), 0); + CHECK_EQ(fst.NumOutputEpsilons(s), s); + CHECK(!matcher.Find(s + 1)); // out-of-range + CHECK(!matcher.Find(kNoLabel)); // no explicit epsilons + CHECK(matcher.Find(0)); + CHECK_EQ(matcher.Value().ilabel, kNoLabel); // implicit epsilon loop + ++ns; + } + CHECK(fst.Properties(kNotAcceptor, true)); + CHECK(fst.Properties(kOEpsilons, true)); + } + + void TestBase() const { TestBase(*testfst_); } + + // This verifies methods specfic to an ExpandedFst. + template + void TestExpanded(const G &fst) const { + StateId ns = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + ++ns; + } + CHECK_EQ(fst.NumStates(), ns); + CHECK(fst.Properties(kExpanded, false)); + } + + void TestExpanded() const { TestExpanded(*testfst_); } + + // This verifies methods specific to a MutableFst. + template + void TestMutable(G *fst) const { + for (StateIterator siter(*fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + size_t na = 0; + size_t ni = fst->NumInputEpsilons(s); + MutableArcIterator aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + } + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + ++na; + Arc arc = aiter.Value(); + arc.ilabel = 0; + aiter.SetValue(arc); + arc = aiter.Value(); + CHECK_EQ(arc.ilabel, 0); + CHECK_EQ(fst->NumInputEpsilons(s), ni + 1); + arc.ilabel = na; + aiter.SetValue(arc); + CHECK_EQ(fst->NumInputEpsilons(s), ni); + } + } + + G *cfst1 = fst->Copy(); + cfst1->DeleteStates(); + CHECK_EQ(cfst1->NumStates(), 0); + delete cfst1; + + G *cfst2 = fst->Copy(); + for (StateIterator siter(*cfst2); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + cfst2->DeleteArcs(s); + CHECK_EQ(cfst2->NumArcs(s), 0); + CHECK_EQ(cfst2->NumInputEpsilons(s), 0); + CHECK_EQ(cfst2->NumOutputEpsilons(s), 0); + } + delete cfst2; + } + + void TestMutable() { TestMutable(testfst_); } + + // This verifies the copy methods. + template + void TestAssign(G *fst) const { + // Assignment from G + G afst1; + afst1 = *fst; + CHECK(Equal(*fst, afst1)); + + // Assignment from Fst + G afst2; + afst2 = *static_cast *>(fst); + CHECK(Equal(*fst, afst2)); + + // Assignment from self + afst2.operator=(afst2); + CHECK(Equal(*fst, afst2)); + } + + void TestAssign() { TestAssign(testfst_); } + + // This verifies the copy methods. + template + void TestCopy(const G &fst) const { + // Copy from G + G c1fst(fst); + TestBase(c1fst); + + // Copy from Fst + const G c2fst(static_cast &>(fst)); + TestBase(c2fst); + + // Copy from self + const G *c3fst = fst.Copy(); + TestBase(*c3fst); + delete c3fst; + } + + void TestCopy() const { TestCopy(*testfst_); } + + // This verifies the read/write methods. + template + void TestIO(const G &fst) const { + const string filename = FLAGS_tmpdir + "/test.fst"; + const string aligned = FLAGS_tmpdir + "/aligned.fst"; + { + // write/read + CHECK(fst.Write(filename)); + G *ffst = G::Read(filename); + CHECK(ffst); + TestBase(*ffst); + delete ffst; + } + + { + // generic read/cast/test + Fst *gfst = Fst::Read(filename); + CHECK(gfst); + G *dfst = static_cast(gfst); + TestBase(*dfst); + + // generic write/read/test + CHECK(gfst->Write(filename)); + Fst *hfst = Fst::Read(filename); + CHECK(hfst); + TestBase(*hfst); + delete gfst; + delete hfst; + } + + { + // check mmaping by first writing the file with the aligned attribute set + { + std::ofstream ostr(aligned); + FstWriteOptions opts; + opts.source = aligned; + opts.align = true; + CHECK(fst.Write(ostr, opts)); + } + std::ifstream istr(aligned); + FstReadOptions opts; + opts.mode = FstReadOptions::ReadMode("map"); + opts.source = aligned; + G *gfst = G::Read(istr, opts); + CHECK(gfst); + TestBase(*gfst); + delete gfst; + } + + // check mmaping of unaligned files to make sure it does not fail. + { + { + std::ofstream ostr(aligned); + FstWriteOptions opts; + opts.source = aligned; + opts.align = false; + CHECK(fst.Write(ostr, opts)); + } + std::ifstream istr(aligned); + FstReadOptions opts; + opts.mode = FstReadOptions::ReadMode("map"); + opts.source = aligned; + G *gfst = G::Read(istr, opts); + CHECK(gfst); + TestBase(*gfst); + delete gfst; + } + + // expanded write/read/test + if (fst.Properties(kExpanded, false)) { + ExpandedFst *efst = ExpandedFst::Read(filename); + CHECK(efst); + TestBase(*efst); + TestExpanded(*efst); + delete efst; + } + + // mutable write/read/test + if (fst.Properties(kMutable, false)) { + MutableFst *mfst = MutableFst::Read(filename); + CHECK(mfst); + TestBase(*mfst); + TestExpanded(*mfst); + TestMutable(mfst); + delete mfst; + } + } + + void TestIO() const { TestIO(*testfst_); } + + private: + // This constructs test FSTs. Given a mutable FST, will leave + // the FST as follows: + // (I) NumStates() = nstates + // (II) Start() = 0 + // (III) Final(s) = NthWeight(s) + // (IV) For state s: + // (a) NumArcs(s) == s + // (b) For ith arc of s: + // (1) ilabel = i + // (2) olabel = 0 + // (3) weight = NthWeight(i) + // (4) nextstate = s + void InitFst(MutableFst *fst, size_t nstates) const { + fst->DeleteStates(); + CHECK_GT(nstates, 0); + + for (StateId s = 0; s < nstates; ++s) { + fst->AddState(); + fst->SetFinal(s, NthWeight(s)); + for (size_t i = 1; i <= s; ++i) { + Arc arc(i, 0, NthWeight(i), s); + fst->AddArc(s, arc); + } + } + + fst->SetStart(0); + } + + // Generates One() + ... + One() (n times) + Weight NthWeight(int n) const { + Weight w = Weight::Zero(); + for (int i = 0; i < n; ++i) w = Plus(w, Weight::One()); + return w; + } + + F *testfst_; // what we're testing +}; + +} // namespace fst + +#endif // FST_TEST_FST_TEST_H_ diff --git a/projects/llm_framework/include/fst/test/rand-fst.h b/projects/llm_framework/include/fst/test/rand-fst.h new file mode 100644 index 00000000..f2f34c67 --- /dev/null +++ b/projects/llm_framework/include/fst/test/rand-fst.h @@ -0,0 +1,90 @@ +#ifndef FST_TEST_RAND_FST_H_ +#define FST_TEST_RAND_FST_H_ + +#include +#include +#include + +namespace fst { + +// Generates a random FST. +template +void RandFst(const int num_random_states, const int num_random_arcs, + const int num_random_labels, const float acyclic_prob, + WeightGenerator *weight_generator, MutableFst *fst) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + // Determines direction of the arcs wrt state numbering. This way we + // can force acyclicity when desired. + enum ArcDirection { + ANY_DIRECTION = 0, + FORWARD_DIRECTION = 1, + REVERSE_DIRECTION = 2, + NUM_DIRECTIONS = 3 + }; + + ArcDirection arc_direction = ANY_DIRECTION; + if (rand() / (RAND_MAX + 1.0) < acyclic_prob) + arc_direction = rand() % 2 ? FORWARD_DIRECTION : REVERSE_DIRECTION; + + fst->DeleteStates(); + StateId ns = rand() % num_random_states; + + if (ns == 0) return; + for (StateId s = 0; s < ns; ++s) fst->AddState(); + + StateId start = rand() % ns; + fst->SetStart(start); + + size_t na = rand() % num_random_arcs; + for (size_t n = 0; n < na; ++n) { + StateId s = rand() % ns; + Arc arc; + arc.ilabel = rand() % num_random_labels; + arc.olabel = rand() % num_random_labels; + arc.weight = (*weight_generator)(); + arc.nextstate = rand() % ns; + + if ((arc_direction == FORWARD_DIRECTION || + arc_direction == REVERSE_DIRECTION) && + s == arc.nextstate) { + continue; // skips self-loops + } + + if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) || + (arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) { + StateId t = s; // reverses arcs + s = arc.nextstate; + arc.nextstate = t; + } + + fst->AddArc(s, arc); + } + + StateId nf = rand() % (ns + 1); + for (StateId n = 0; n < nf; ++n) { + StateId s = rand() % ns; + Weight final = (*weight_generator)(); + fst->SetFinal(s, final); + } + VLOG(1) << "Check FST for sanity (including property bits)."; + CHECK(Verify(*fst)); + + // Get/compute all properties. + uint64 props = fst->Properties(kFstProperties, true); + + // Select random set of properties to be unknown. + uint64 mask = 0; + for (int n = 0; n < 8; ++n) { + mask |= rand() & 0xff; + mask <<= 8; + } + mask &= ~kTrinaryProperties; + fst->SetProperties(props & ~mask, mask); +} + +} // namespace fst + +#endif // FST_TEST_RAND_FST_H_ diff --git a/projects/llm_framework/include/fst/test/weight-tester.h b/projects/llm_framework/include/fst/test/weight-tester.h new file mode 100644 index 00000000..b7de665e --- /dev/null +++ b/projects/llm_framework/include/fst/test/weight-tester.h @@ -0,0 +1,207 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utility class for regression testing of FST weights. + +#ifndef FST_TEST_WEIGHT_TESTER_H_ +#define FST_TEST_WEIGHT_TESTER_H_ + +#include +#include + +#include + +#include +#include + +namespace fst { + +// This class tests a variety of identities and properties that must +// hold for the Weight class to be well-defined. It calls function object +// WEIGHT_GENERATOR to select weights that are used in the tests. +template +class WeightTester { + public: + WeightTester(WeightGenerator generator) + : weight_generator_(std::move(generator)) {} + + void Test(int iterations, bool test_division = true) { + for (int i = 0; i < iterations; ++i) { + // Selects the test weights. + const Weight w1(weight_generator_()); + const Weight w2(weight_generator_()); + const Weight w3(weight_generator_()); + + VLOG(1) << "weight type = " << Weight::Type(); + VLOG(1) << "w1 = " << w1; + VLOG(1) << "w2 = " << w2; + VLOG(1) << "w3 = " << w3; + + TestSemiring(w1, w2, w3); + if (test_division) TestDivision(w1, w2); + TestReverse(w1, w2); + TestEquality(w1, w2, w3); + TestIO(w1); + TestCopy(w1); + } + } + + private: + // Note in the tests below we use ApproxEqual rather than == and add + // kDelta to inequalities where the weights might be inexact. + + // Tests (Plus, Times, Zero, One) defines a commutative semiring. + void TestSemiring(Weight w1, Weight w2, Weight w3) { + // Checks that the operations are closed. + CHECK(Plus(w1, w2).Member()); + CHECK(Times(w1, w2).Member()); + + // Checks that the operations are associative. + CHECK(ApproxEqual(Plus(w1, Plus(w2, w3)), Plus(Plus(w1, w2), w3))); + CHECK(ApproxEqual(Times(w1, Times(w2, w3)), Times(Times(w1, w2), w3))); + + // Checks the identity elements. + CHECK(Plus(w1, Weight::Zero()) == w1); + CHECK(Plus(Weight::Zero(), w1) == w1); + CHECK(Times(w1, Weight::One()) == w1); + CHECK(Times(Weight::One(), w1) == w1); + + // Check the no weight element. + CHECK(!Weight::NoWeight().Member()); + CHECK(!Plus(w1, Weight::NoWeight()).Member()); + CHECK(!Plus(Weight::NoWeight(), w1).Member()); + CHECK(!Times(w1, Weight::NoWeight()).Member()); + CHECK(!Times(Weight::NoWeight(), w1).Member()); + + // Checks that the operations commute. + CHECK(ApproxEqual(Plus(w1, w2), Plus(w2, w1))); + + if (Weight::Properties() & kCommutative) + CHECK(ApproxEqual(Times(w1, w2), Times(w2, w1))); + + // Checks Zero() is the annihilator. + CHECK(Times(w1, Weight::Zero()) == Weight::Zero()); + CHECK(Times(Weight::Zero(), w1) == Weight::Zero()); + + // Check Power(w, 0) is Weight::One() + CHECK(Power(w1, 0) == Weight::One()); + + // Check Power(w, 1) is w + CHECK(Power(w1, 1) == w1); + + // Check Power(w, 3) is Times(w, Times(w, w)) + CHECK(Power(w1, 3) == Times(w1, Times(w1, w1))); + + // Checks distributivity. + if (Weight::Properties() & kLeftSemiring) { + CHECK(ApproxEqual(Times(w1, Plus(w2, w3)), + Plus(Times(w1, w2), Times(w1, w3)))); + } + if (Weight::Properties() & kRightSemiring) + CHECK(ApproxEqual(Times(Plus(w1, w2), w3), + Plus(Times(w1, w3), Times(w2, w3)))); + + if (Weight::Properties() & kIdempotent) CHECK(Plus(w1, w1) == w1); + + if (Weight::Properties() & kPath) + CHECK(Plus(w1, w2) == w1 || Plus(w1, w2) == w2); + + // Ensure weights form a left or right semiring. + CHECK(Weight::Properties() & (kLeftSemiring | kRightSemiring)); + + // Check when Times() is commutative that it is marked as a semiring. + if (Weight::Properties() & kCommutative) + CHECK(Weight::Properties() & kSemiring); + } + + // Tests division operation. + void TestDivision(Weight w1, Weight w2) { + Weight p = Times(w1, w2); + + if (Weight::Properties() & kLeftSemiring) { + Weight d = Divide(p, w1, DIVIDE_LEFT); + if (d.Member()) CHECK(ApproxEqual(p, Times(w1, d))); + CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_LEFT).Member()); + CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_LEFT).Member()); + } + + if (Weight::Properties() & kRightSemiring) { + Weight d = Divide(p, w2, DIVIDE_RIGHT); + if (d.Member()) CHECK(ApproxEqual(p, Times(d, w2))); + CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_RIGHT).Member()); + CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_RIGHT).Member()); + } + + if (Weight::Properties() & kCommutative) { + Weight d = Divide(p, w1, DIVIDE_RIGHT); + if (d.Member()) CHECK(ApproxEqual(p, Times(d, w1))); + } + } + + // Tests reverse operation. + void TestReverse(Weight w1, Weight w2) { + typedef typename Weight::ReverseWeight ReverseWeight; + + ReverseWeight rw1 = w1.Reverse(); + ReverseWeight rw2 = w2.Reverse(); + + CHECK(rw1.Reverse() == w1); + CHECK(Plus(w1, w2).Reverse() == Plus(rw1, rw2)); + CHECK(Times(w1, w2).Reverse() == Times(rw2, rw1)); + } + + // Tests == is an equivalence relation. + void TestEquality(Weight w1, Weight w2, Weight w3) { + // Checks reflexivity. + CHECK(w1 == w1); + + // Checks symmetry. + CHECK((w1 == w2) == (w2 == w1)); + + // Checks transitivity. + if (w1 == w2 && w2 == w3) CHECK(w1 == w3); + } + + // Tests binary serialization and textual I/O. + void TestIO(Weight w) { + // Tests binary I/O + { + std::ostringstream os; + w.Write(os); + os.flush(); + std::istringstream is(os.str()); + Weight v; + v.Read(is); + CHECK_EQ(w, v); + } + + // Tests textual I/O. + { + std::ostringstream os; + os << w; + std::istringstream is(os.str()); + Weight v(Weight::One()); + is >> v; + CHECK(ApproxEqual(w, v)); + } + } + + // Tests copy constructor and assignment operator + void TestCopy(Weight w) { + Weight x = w; + CHECK(w == x); + + x = Weight(w); + CHECK(w == x); + + x.operator=(x); + CHECK(w == x); + } + + // Generates weights used in testing. + WeightGenerator weight_generator_; +}; + +} // namespace fst + +#endif // FST_TEST_WEIGHT_TESTER_H_ diff --git a/projects/llm_framework/include/fst/topsort.h b/projects/llm_framework/include/fst/topsort.h new file mode 100644 index 00000000..cae3e154 --- /dev/null +++ b/projects/llm_framework/include/fst/topsort.h @@ -0,0 +1,95 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Topological sort of FSTs. + +#ifndef FST_TOPSORT_H_ +#define FST_TOPSORT_H_ + +#include +#include + + +#include +#include +#include + + +namespace fst { + +// DFS visitor class to return topological ordering. +template +class TopOrderVisitor { + public: + using StateId = typename Arc::StateId; + + // If acyclic, order[i] gives the topological position of StateId i; + // otherwise it is unchanged. acyclic_ will be true iff the FST has no + // cycles. The caller retains ownership of the state order vector. + TopOrderVisitor(std::vector *order, bool *acyclic) + : order_(order), acyclic_(acyclic) {} + + void InitVisit(const Fst &fst) { + finish_.reset(new std::vector()); + *acyclic_ = true; + } + + constexpr bool InitState(StateId, StateId) const { return true; } + + constexpr bool TreeArc(StateId, const Arc &) const { return true; } + + bool BackArc(StateId, const Arc &) { return (*acyclic_ = false); } + + constexpr bool ForwardOrCrossArc(StateId, const Arc &) const { return true; } + + void FinishState(StateId s, StateId, const Arc *) { finish_->push_back(s); } + + void FinishVisit() { + if (*acyclic_) { + order_->clear(); + for (StateId s = 0; s < finish_->size(); ++s) { + order_->push_back(kNoStateId); + } + for (StateId s = 0; s < finish_->size(); ++s) { + (*order_)[(*finish_)[finish_->size() - s - 1]] = s; + } + } + finish_.reset(); + } + + private: + std::vector *order_; + bool *acyclic_; + // States in finish-time order. + std::unique_ptr> finish_; +}; + +// Topologically sorts its input if acyclic, modifying it. Otherwise, the input +// is unchanged. When sorted, all transitions are from lower to higher state +// IDs. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(V + E) +// +// where V is the number of states and E is the number of arcs. +template +bool TopSort(MutableFst *fst) { + std::vector order; + bool acyclic; + TopOrderVisitor top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor); + if (acyclic) { + StateSort(fst, order); + fst->SetProperties(kAcyclic | kInitialAcyclic | kTopSorted, + kAcyclic | kInitialAcyclic | kTopSorted); + } else { + fst->SetProperties(kCyclic | kNotTopSorted, kCyclic | kNotTopSorted); + } + return acyclic; +} + +} // namespace fst + +#endif // FST_TOPSORT_H_ diff --git a/projects/llm_framework/include/fst/tuple-weight.h b/projects/llm_framework/include/fst/tuple-weight.h new file mode 100644 index 00000000..c8ecab76 --- /dev/null +++ b/projects/llm_framework/include/fst/tuple-weight.h @@ -0,0 +1,163 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Tuple weight set operation definitions. + +#ifndef FST_TUPLE_WEIGHT_H_ +#define FST_TUPLE_WEIGHT_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +// n-tuple weight, element of the n-th Cartesian power of W. +template +class TupleWeight { + public: + using ReverseWeight = TupleWeight; + + using Weight = W; + using Index = size_t; + + template + TupleWeight(Iterator begin, Iterator end) { + std::copy(begin, end, values_.begin()); + } + + explicit TupleWeight(const W &weight = W::Zero()) { values_.fill(weight); } + + // Initialize component `index` to `weight`; initialize all other components + // to `default_weight` + TupleWeight(Index index, const W &weight, const W &default_weight) + : TupleWeight(default_weight) { + values_[index] = weight; + } + + static const TupleWeight &Zero() { + static const TupleWeight zero(W::Zero()); + return zero; + } + + static const TupleWeight &One() { + static const TupleWeight one(W::One()); + return one; + } + + static const TupleWeight &NoWeight() { + static const TupleWeight no_weight(W::NoWeight()); + return no_weight; + } + + constexpr static size_t Length() { return n; } + + std::istream &Read(std::istream &istrm) { + for (size_t i = 0; i < n; ++i) values_[i].Read(istrm); + return istrm; + } + + std::ostream &Write(std::ostream &ostrm) const { + for (size_t i = 0; i < n; ++i) values_[i].Write(ostrm); + return ostrm; + } + + bool Member() const { + return std::all_of(values_.begin(), values_.end(), + std::mem_fn(&W::Member)); + } + + size_t Hash() const { + uint64 hash = 0; + for (size_t i = 0; i < n; ++i) hash = 5 * hash + values_[i].Hash(); + return size_t(hash); + } + + TupleWeight Quantize(float delta = kDelta) const { + TupleWeight weight; + for (size_t i = 0; i < n; ++i) { + weight.values_[i] = values_[i].Quantize(delta); + } + return weight; + } + + ReverseWeight Reverse() const { + TupleWeight w; + for (size_t i = 0; i < n; ++i) w.values_[i] = values_[i].Reverse(); + return w; + } + + const W &Value(size_t i) const { return values_[i]; } + + void SetValue(size_t i, const W &w) { values_[i] = w; } + + private: + std::array values_; +}; + +template +inline bool operator==(const TupleWeight &w1, + const TupleWeight &w2) { + for (size_t i = 0; i < n; ++i) { + if (w1.Value(i) != w2.Value(i)) return false; + } + return true; +} + +template +inline bool operator!=(const TupleWeight &w1, + const TupleWeight &w2) { + for (size_t i = 0; i < n; ++i) { + if (w1.Value(i) != w2.Value(i)) return true; + } + return false; +} + +template +inline bool ApproxEqual(const TupleWeight &w1, + const TupleWeight &w2, float delta = kDelta) { + for (size_t i = 0; i < n; ++i) { + if (!ApproxEqual(w1.Value(i), w2.Value(i), delta)) return false; + } + return true; +} + +template +inline std::ostream &operator<<(std::ostream &strm, + const TupleWeight &w) { + CompositeWeightWriter writer(strm); + writer.WriteBegin(); + for (size_t i = 0; i < n; ++i) writer.WriteElement(w.Value(i)); + writer.WriteEnd(); + return strm; +} + +template +inline std::istream &operator>>(std::istream &strm, TupleWeight &w) { + CompositeWeightReader reader(strm); + reader.ReadBegin(); + W v; + // Reads first n-1 elements. + static_assert(n > 0, "Size must be positive."); + for (size_t i = 0; i < n - 1; ++i) { + reader.ReadElement(&v); + w.SetValue(i, v); + } + // Reads n-th element. + reader.ReadElement(&v, true); + w.SetValue(n - 1, v); + reader.ReadEnd(); + return strm; +} + +} // namespace fst + +#endif // FST_TUPLE_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/types.h b/projects/llm_framework/include/fst/types.h new file mode 100644 index 00000000..9c0b7998 --- /dev/null +++ b/projects/llm_framework/include/fst/types.h @@ -0,0 +1,41 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Various type definitions (mostly for Google compatibility). + +#include // For std::ptrdiff_t. +#include // for ssize_t. +#include // for ?int*_t. + +#ifndef FST_LIB_TYPES_H_ +#define FST_LIB_TYPES_H_ + +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; + +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; + +#ifdef _MSC_VER +// Not really Windows-specific: they should have used ptrdiff_t in the first +// place. But on Windows there has never been ssize_t. +using ssize_t = std::ptrdiff_t; +#endif // _MSC_VER + +#endif // FST_LIB_TYPES_H_ diff --git a/projects/llm_framework/include/fst/union-find.h b/projects/llm_framework/include/fst/union-find.h new file mode 100644 index 00000000..b5c7f10b --- /dev/null +++ b/projects/llm_framework/include/fst/union-find.h @@ -0,0 +1,84 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Union-find algorithm for dense sets of non-negative integers, implemented +// using disjoint tree forests with rank heuristics and path compression. + +#ifndef FST_UNION_FIND_H_ +#define FST_UNION_FIND_H_ + +#include +#include + +namespace fst { + +// Union-Find algorithm for dense sets of non-negative integers. +template +class UnionFind { + public: + // Creates a disjoint set forest for the range [0; max); 'fail' is a value + // indicating that an element hasn't been initialized using MakeSet(...). + // The upper bound of the range can be reset (increased) using MakeSet(...). + UnionFind(T max, T fail) : parent_(max, fail), rank_(max), fail_(fail) {} + + // Finds the representative of the set 'item' belongs to, performing path + // compression if necessary. + T FindSet(T item) { + if (item >= parent_.size() || item == fail_ || parent_[item] == fail_) { + return fail_; + } + auto *p = &parent_[item]; + for (; *p != item; item = *p, p = &parent_[item]) exec_stack_.push(p); + for (; !exec_stack_.empty(); exec_stack_.pop()) *exec_stack_.top() = *p; + return *p; + } + + // Creates the (destructive) union of the sets x and y belong to. + void Union(T x, T y) { Link(FindSet(x), FindSet(y)); } + + // Initialization of an element: creates a singleton set containing 'item'. + // The range [0; max) is reset if item >= max. + T MakeSet(T item) { + if (item >= parent_.size()) { + // New value in parent_ should be initialized to fail_. + const auto nitem = item > 0 ? 2 * item : 2; + parent_.resize(nitem, fail_); + rank_.resize(nitem); + } + parent_[item] = item; + return item; + } + + // Initialization of all elements starting from 0 to max - 1 to distinct sets. + void MakeAllSet(T max) { + parent_.resize(max); + for (T item = 0; item < max; ++item) parent_[item] = item; + } + + private: + // Links trees rooted in 'x' and 'y'. + void Link(T x, T y) { + if (x == y) return; + if (rank_[x] > rank_[y]) { + parent_[y] = x; + } else { + parent_[x] = y; + if (rank_[x] == rank_[y]) { + ++rank_[y]; + } + } + } + + UnionFind(const UnionFind &) = delete; + + UnionFind &operator=(const UnionFind &) = delete; + + std::vector parent_; // Parent nodes. + std::vector rank_; // Rank of an element = min. depth in tree. + T fail_; // Value indicating lookup failure. + std::stack exec_stack_; // Used for path compression. +}; + +} // namespace fst + +#endif // FST_UNION_FIND_H_ diff --git a/projects/llm_framework/include/fst/union-weight.h b/projects/llm_framework/include/fst/union-weight.h new file mode 100644 index 00000000..bb2dea96 --- /dev/null +++ b/projects/llm_framework/include/fst/union-weight.h @@ -0,0 +1,505 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Union weight set and associated semiring operation definitions. +// +// TODO(riley): add in normalizer functor + +#ifndef FST_UNION_WEIGHT_H_ +#define FST_UNION_WEIGHT_H_ + +#include + +#include +#include +#include +#include +#include + +#include + + +namespace fst { + +// Example UnionWeightOptions for UnionWeight template below. The Merge +// operation is used to collapse elements of the set and the Compare function +// to efficiently implement the merge. In the simplest case, merge would just +// apply with equality of set elements so the result is a set (and not a +// multiset). More generally, this can be used to maintain the multiplicity or +// other such weight associated with the set elements (cf. Gallic weights). + +// template +// struct UnionWeightOptions { +// // Comparison function C is a total order on W that is monotonic w.r.t. to +// // Times: for all a, b,c != Zero(): C(a, b) => C(ca, cb) and is +// // anti-monotonic w.r.rt to Divide: C(a, b) => C(c/b, c/a). +// // +// // For all a, b: only one of C(a, b), C(b, a) or a ~ b must true where +// // ~ is an equivalence relation on W. Also we require a ~ b iff +// // a.Reverse() ~ b.Reverse(). +// using Compare = NaturalLess; +// +// // How to combine two weights if a ~ b as above. For all a, b: a ~ b => +// // merge(a, b) ~ a, Merge must define a semiring endomorphism from the +// // unmerged weight sets to the merged weight sets. +// struct Merge { +// W operator()(const W &w1, const W &w2) const { return w1; } +// }; +// +// // For ReverseWeight. +// using ReverseOptions = UnionWeightOptions; +// }; + +template +class UnionWeight; + +template +class UnionWeightIterator; + +template +class UnionWeightReverseIterator; + +template +bool operator==(const UnionWeight &, const UnionWeight &); + +// Semiring that uses Times() and One() from W and union and the empty set +// for Plus() and Zero(), respectively. Template argument O specifies the union +// weight options as above. +template +class UnionWeight { + public: + using Weight = W; + using Compare = typename O::Compare; + using Merge = typename O::Merge; + + using ReverseWeight = + UnionWeight; + + friend class UnionWeightIterator; + friend class UnionWeightReverseIterator; + friend bool operator== + <>(const UnionWeight &, const UnionWeight &); + + // Sets represented as first_ weight + rest_ weights. Uses first_ as + // NoWeight() to indicate the union weight Zero() ask the empty set. Uses + // rest_ containing NoWeight() to indicate the union weight NoWeight(). + UnionWeight() : first_(W::NoWeight()) {} + + explicit UnionWeight(W weight) : first_(weight) { + if (weight == W::NoWeight()) rest_.push_back(weight); + } + + static const UnionWeight &Zero() { + static const UnionWeight zero(W::NoWeight()); + return zero; + } + + static const UnionWeight &One() { + static const UnionWeight one(W::One()); + return one; + } + + static const UnionWeight &NoWeight() { + static const UnionWeight no_weight(W::Zero(), W::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = new string(W::Type() + "_union"); + return *type; + } + + static constexpr uint64 Properties() { + return W::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + UnionWeight Quantize(float delta = kDelta) const; + + ReverseWeight Reverse() const; + + // These operations combined with the UnionWeightIterator and + // UnionWeightReverseIterator provide the access and mutation of the union + // weight internal elements. + + // Common initializer among constructors; clears existing UnionWeight. + void Clear() { + first_ = W::NoWeight(); + rest_.clear(); + } + + size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; } + + const W &Back() const { return rest_.empty() ? first_ : rest_.back(); } + + // When srt is true, assumes elements added sorted w.r.t Compare and merging + // of weights performed as needed. Otherwise, just ensures first_ is the + // least element wrt Compare. + void PushBack(W weight, bool srt); + + // Sorts the elements of the set. Assumes that first_, if present, is the + // least element. + void Sort() { rest_.sort(comp_); } + + private: + W &Back() { + if (rest_.empty()) { + return first_; + } else { + return rest_.back(); + } + } + + UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {} + + W first_; // First weight in set. + std::list rest_; // Remaining weights in set. + Compare comp_; + Merge merge_; +}; + +template +void UnionWeight::PushBack(W weight, bool srt) { + if (!weight.Member()) { + rest_.push_back(std::move(weight)); + } else if (!first_.Member()) { + first_ = std::move(weight); + } else if (srt) { + auto &back = Back(); + if (comp_(back, weight)) { + rest_.push_back(std::move(weight)); + } else { + back = merge_(back, std::move(weight)); + } + } else { + if (comp_(first_, weight)) { + rest_.push_back(std::move(weight)); + } else { + rest_.push_back(first_); + first_ = std::move(weight); + } + } +} + +// Traverses union weight in the forward direction. +template +class UnionWeightIterator { + public: + explicit UnionWeightIterator(const UnionWeight &weight) + : first_(weight.first_), + rest_(weight.rest_), + init_(true), + it_(rest_.begin()) {} + + bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); } + + const W &Value() const { return init_ ? first_ : *it_; } + + void Next() { + if (init_) { + init_ = false; + } else { + ++it_; + } + } + + void Reset() { + init_ = true; + it_ = rest_.begin(); + } + + private: + const W &first_; + const std::list &rest_; + bool init_; // in the initialized state? + typename std::list::const_iterator it_; +}; + +// Traverses union weight in backward direction. +template +class UnionWeightReverseIterator { + public: + explicit UnionWeightReverseIterator(const UnionWeight &weight) + : first_(weight.first_), + rest_(weight.rest_), + fin_(!first_.Member()), + it_(rest_.rbegin()) {} + + bool Done() const { return fin_; } + + const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; } + + void Next() { + if (it_ == rest_.rend()) { + fin_ = true; + } else { + ++it_; + } + } + + void Reset() { + fin_ = !first_.Member(); + it_ = rest_.rbegin(); + } + + private: + const L &first_; + const std::list &rest_; + bool fin_; // in the final state? + typename std::list::const_reverse_iterator it_; +}; + +// UnionWeight member functions follow that require UnionWeightIterator. +template +inline std::istream &UnionWeight::Read(std::istream &istrm) { + Clear(); + int32 size; + ReadType(istrm, &size); + for (int i = 0; i < size; ++i) { + W weight; + ReadType(istrm, &weight); + PushBack(weight, true); + } + return istrm; +} + +template +inline std::ostream &UnionWeight::Write(std::ostream &ostrm) const { + const int32 size = Size(); + WriteType(ostrm, size); + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + WriteType(ostrm, it.Value()); + } + return ostrm; +} + +template +inline bool UnionWeight::Member() const { + if (Size() <= 1) return true; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + if (!it.Value().Member()) return false; + } + return true; +} + +template +inline UnionWeight UnionWeight::Quantize(float delta) const { + UnionWeight weight; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + weight.PushBack(it.Value().Quantize(delta), true); + } + return weight; +} + +template +inline typename UnionWeight::ReverseWeight UnionWeight::Reverse() + const { + ReverseWeight weight; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + weight.PushBack(it.Value().Reverse(), false); + } + weight.Sort(); + return weight; +} + +template +inline size_t UnionWeight::Hash() const { + size_t h = 0; + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + h = h << lshift ^ h >> rshift ^ it.Value().Hash(); + } + return h; +} + +// Requires union weight has been canonicalized. +template +inline bool operator==(const UnionWeight &w1, + const UnionWeight &w2) { + if (w1.Size() != w2.Size()) return false; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + for (; !it1.Done(); it1.Next(), it2.Next()) { + if (it1.Value() != it2.Value()) return false; + } + return true; +} + +// Requires union weight has been canonicalized. +template +inline bool operator!=(const UnionWeight &w1, + const UnionWeight &w2) { + return !(w1 == w2); +} + +// Requires union weight has been canonicalized. +template +inline bool ApproxEqual(const UnionWeight &w1, + const UnionWeight &w2, float delta = kDelta) { + if (w1.Size() != w2.Size()) return false; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + for (; !it1.Done(); it1.Next(), it2.Next()) { + if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false; + } + return true; +} + +template +inline std::ostream &operator<<(std::ostream &ostrm, + const UnionWeight &weight) { + UnionWeightIterator it(weight); + if (it.Done()) { + return ostrm << "EmptySet"; + } else if (!weight.Member()) { + return ostrm << "BadSet"; + } else { + CompositeWeightWriter writer(ostrm); + writer.WriteBegin(); + for (; !it.Done(); it.Next()) writer.WriteElement(it.Value()); + writer.WriteEnd(); + } + return ostrm; +} + +template +inline std::istream &operator>>(std::istream &istrm, + UnionWeight &weight) { + string s; + istrm >> s; + if (s == "EmptySet") { + weight = UnionWeight::Zero(); + } else if (s == "BadSet") { + weight = UnionWeight::NoWeight(); + } else { + weight = UnionWeight::Zero(); + std::istringstream sstrm(s); + CompositeWeightReader reader(sstrm); + reader.ReadBegin(); + bool more = true; + while (more) { + W v; + more = reader.ReadElement(&v); + weight.PushBack(v, true); + } + reader.ReadEnd(); + } + return istrm; +} + +template +inline UnionWeight Plus(const UnionWeight &w1, + const UnionWeight &w2) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero()) return w2; + if (w2 == UnionWeight::Zero()) return w1; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + UnionWeight sum; + typename O::Compare comp; + while (!it1.Done() && !it2.Done()) { + const auto v1 = it1.Value(); + const auto v2 = it2.Value(); + if (comp(v1, v2)) { + sum.PushBack(v1, true); + it1.Next(); + } else { + sum.PushBack(v2, true); + it2.Next(); + } + } + for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true); + for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true); + return sum; +} + +template +inline UnionWeight Times(const UnionWeight &w1, + const UnionWeight &w2) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero() || w2 == UnionWeight::Zero()) { + return UnionWeight::Zero(); + } + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + UnionWeight prod1; + for (; !it1.Done(); it1.Next()) { + UnionWeight prod2; + for (; !it2.Done(); it2.Next()) { + prod2.PushBack(Times(it1.Value(), it2.Value()), true); + } + prod1 = Plus(prod1, prod2); + it2.Reset(); + } + return prod1; +} + +template +inline UnionWeight Divide(const UnionWeight &w1, + const UnionWeight &w2, DivideType typ) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero() || w2 == UnionWeight::Zero()) { + return UnionWeight::Zero(); + } + UnionWeightIterator it1(w1); + UnionWeightReverseIterator it2(w2); + UnionWeight quot; + if (w1.Size() == 1) { + for (; !it2.Done(); it2.Next()) { + quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true); + } + } else if (w2.Size() == 1) { + for (; !it1.Done(); it1.Next()) { + quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true); + } + } else { + quot = UnionWeight::NoWeight(); + } + return quot; +} + +// This function object generates weights over the union of weights for the +// underlying generators for the template weight types. This is intended +// primarily for testing. +template +class WeightGenerate> { + public: + using Weight = UnionWeight; + using Generate = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true, + size_t num_random_weights = kNumRandomWeights) + : generate_(false), allow_zero_(allow_zero), + num_random_weights_(num_random_weights) {} + + Weight operator()() const { + const int n = rand() % (num_random_weights_ + 1); // NOLINT + if (allow_zero_ && n == num_random_weights_) { + return Weight::Zero(); + } else if (n % 2 == 0) { + return Weight(generate_()); + } else { + return Plus(Weight(generate_()), Weight(generate_())); + } + } + + private: + Generate generate_; + // Permits Zero() and zero divisors. + bool allow_zero_; + // The number of alternative random weights. + const size_t num_random_weights_; +}; + +} // namespace fst + +#endif // FST_UNION_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/union.h b/projects/llm_framework/include/fst/union.h new file mode 100644 index 00000000..257099b8 --- /dev/null +++ b/projects/llm_framework/include/fst/union.h @@ -0,0 +1,157 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the union of two FSTs. + +#ifndef FST_UNION_H_ +#define FST_UNION_H_ + +#include +#include +#include + +#include +#include + + +namespace fst { + +// Computes the union (sum) of two FSTs. This version writes the union to an +// output MutableFst. If A transduces string x to y with weight a and B +// transduces string w to v with weight b, then their union transduces x to y +// with weight a and w to v with weight b. +// +// Complexity: +// +// Time: (V_2 + E_2) +// Space: O(V_2 + E_2) +// +// where Vi is the number of states, and Ei is the number of arcs, in the ith +// FST. +template +void Union(MutableFst *fst1, const Fst &fst2) { + using Weight = typename Arc::Weight; + // Checks for symbol table compatibility. + if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Union: Input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + fst1->SetProperties(kError, kError); + return; + } + const auto numstates1 = fst1->NumStates(); + const bool initial_acyclic1 = fst1->Properties(kInitialAcyclic, true); + const auto props1 = fst1->Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + const auto start2 = fst2.Start(); + if (start2 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + if (fst2.Properties(kExpanded, false)) { + fst1->ReserveStates(numstates1 + CountStates(fst2) + + (initial_acyclic1 ? 0 : 1)); + } + for (StateIterator> siter(fst2); !siter.Done(); siter.Next()) { + const auto s1 = fst1->AddState(); + const auto s2 = siter.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator> aiter(fst2, s2); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); // Copy intended. + arc.nextstate += numstates1; + fst1->AddArc(s1, std::move(arc)); + } + } + const auto start1 = fst1->Start(); + if (start1 == kNoStateId) { + fst1->SetStart(start2); + fst1->SetProperties(props2, kCopyProperties); + return; + } + if (initial_acyclic1) { + fst1->AddArc(start1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } else { + const auto nstart1 = fst1->AddState(); + fst1->SetStart(nstart1); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start1)); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } + fst1->SetProperties(UnionProperties(props1, props2), kFstProperties); +} + +// Computes the union of two FSTs, modifying the RationalFst argument. +template +void Union(RationalFst *fst1, const Fst &fst2) { + fst1->GetMutableImpl()->AddUnion(fst2); +} + +using UnionFstOptions = RationalFstOptions; + +// Computes the union (sum) of two FSTs. This version is a delayed FST. If A +// transduces string x to y with weight a and B transduces string w to v with +// weight b, then their union transduces x to y with weight a and w to v with +// weight b. +// +// Complexity: +// +// Time: O(v_1 + e_1 + v_2 + e_2) +// Space: O(v_1 + v_2) +// +// where vi is the number of states visited, and ei is the number of arcs +// visited, in the ith FST. Constant time and space to visit an input state or +// arc is assumed and exclusive of caching. +template +class UnionFst : public RationalFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + UnionFst(const Fst &fst1, const Fst &fst2) { + GetMutableImpl()->InitUnion(fst1, fst2); + } + + UnionFst(const Fst &fst1, const Fst &fst2, + const UnionFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitUnion(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + UnionFst(const UnionFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Gets a copy of this UnionFst. See Fst<>::Copy() for further doc. + UnionFst *Copy(bool safe = false) const override { + return new UnionFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for UnionFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const UnionFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for UnionFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const UnionFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +using StdUnionFst = UnionFst; + +} // namespace fst + +#endif // FST_UNION_H_ diff --git a/projects/llm_framework/include/fst/util.h b/projects/llm_framework/include/fst/util.h new file mode 100644 index 00000000..c7520213 --- /dev/null +++ b/projects/llm_framework/include/fst/util.h @@ -0,0 +1,400 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST utility inline definitions. + +#ifndef FST_UTIL_H_ +#define FST_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + + +// Utility for error handling. + +DECLARE_bool(fst_error_fatal); + +#define FSTERROR() \ + (FLAGS_fst_error_fatal ? LOG(FATAL) : LOG(ERROR)) + +namespace fst { + +// Utility for type I/O. + +// Reads types from an input stream. + +// Generic case. +template ::value, T>::type* = nullptr> +inline std::istream &ReadType(std::istream &strm, T *t) { + return t->Read(strm); +} + +// Numeric (boolean, integral, floating-point) case. +template ::value, T>::type* = nullptr> +inline std::istream &ReadType(std::istream &strm, T *t) { + return strm.read(reinterpret_cast(t), sizeof(T)); \ +} + +// String case. +inline std::istream &ReadType(std::istream &strm, string *s) { // NOLINT + s->clear(); + int32 ns = 0; + strm.read(reinterpret_cast(&ns), sizeof(ns)); + for (int32 i = 0; i < ns; ++i) { + char c; + strm.read(&c, 1); + *s += c; + } + return strm; +} + +// Pair case. +template +inline std::istream &ReadType(std::istream &strm, std::pair *p) { + ReadType(strm, &p->first); + ReadType(strm, &p->second); + return strm; +} + +template +inline std::istream &ReadType(std::istream &strm, std::pair *p) { + ReadType(strm, const_cast(&p->first)); + ReadType(strm, &p->second); + return strm; +} + +namespace internal { +template +std::istream &ReadContainerType(std::istream &strm, C *c, ReserveFn reserve) { + c->clear(); + int64 n = 0; + ReadType(strm, &n); + reserve(c, n); + auto insert = std::inserter(*c, c->begin()); + for (int64 i = 0; i < n; ++i) { + typename C::value_type value; + ReadType(strm, &value); + *insert = value; + } + return strm; +} +} // namespace internal + +template +std::istream &ReadType(std::istream &strm, std::vector *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +template +std::istream &ReadType(std::istream &strm, std::list *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::set *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::map *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::unordered_set *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +template +std::istream &ReadType(std::istream &strm, std::unordered_map *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +// Writes types to an output stream. + +// Generic case. +template ::value, T>::type* = nullptr> +inline std::ostream &WriteType(std::ostream &strm, const T t) { + t.Write(strm); + return strm; +} + +// Numeric (boolean, integral, floating-point) case. +template ::value, T>::type* = nullptr> +inline std::ostream &WriteType(std::ostream &strm, const T t) { + return strm.write(reinterpret_cast(&t), sizeof(T)); +} + +// String case. +inline std::ostream &WriteType(std::ostream &strm, const string &s) { // NOLINT + int32 ns = s.size(); + strm.write(reinterpret_cast(&ns), sizeof(ns)); + return strm.write(s.data(), ns); +} + +// Pair case. +template +inline std::ostream &WriteType(std::ostream &strm, + const std::pair &p) { // NOLINT + WriteType(strm, p.first); + WriteType(strm, p.second); + return strm; +} + +namespace internal { +template +std::ostream &WriteContainer(std::ostream &strm, const C &c) { + const int64 n = c.size(); + WriteType(strm, n); + for (const auto &e : c) { + WriteType(strm, e); + } + return strm; +} +} // namespace internal + +template +std::ostream &WriteType(std::ostream &strm, const std::vector &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::list &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::set &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::map &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::unordered_set &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::unordered_map &c) { + return internal::WriteContainer(strm, c); +} + +// Utilities for converting between int64 or Weight and string. + +int64 StrToInt64(const string &s, const string &src, size_t nline, + bool allow_negative, bool *error = nullptr); + +template +Weight StrToWeight(const string &s, const string &src, size_t nline) { + Weight w; + std::istringstream strm(s); + strm >> w; + if (!strm) { + FSTERROR() << "StrToWeight: Bad weight = \"" << s << "\", source = " << src + << ", line = " << nline; + return Weight::NoWeight(); + } + return w; +} + +template +void WeightToStr(Weight w, string *s) { + std::ostringstream strm; + strm.precision(9); + strm << w; + s->append(strm.str().data(), strm.str().size()); +} + +// Utilities for reading/writing integer pairs (typically labels) + +// Modifies line using a vector of pointers to a buffer beginning with line. +void SplitString(char *line, const char *delim, std::vector *vec, + bool omit_empty_strings); + +template +bool ReadIntPairs(const string &filename, std::vector> *pairs, + bool allow_negative = false) { + std::ifstream strm(filename, std::ios_base::in); + if (!strm) { + LOG(ERROR) << "ReadIntPairs: Can't open file: " << filename; + return false; + } + const int kLineLen = 8096; + char line[kLineLen]; + size_t nline = 0; + pairs->clear(); + while (strm.getline(line, kLineLen)) { + ++nline; + std::vector col; + SplitString(line, "\n\t ", &col, true); + // empty line or comment? + if (col.empty() || col[0][0] == '\0' || col[0][0] == '#') continue; + if (col.size() != 2) { + LOG(ERROR) << "ReadIntPairs: Bad number of columns, " + << "file = " << filename << ", line = " << nline; + return false; + } + bool err; + I i1 = StrToInt64(col[0], filename, nline, allow_negative, &err); + if (err) return false; + I i2 = StrToInt64(col[1], filename, nline, allow_negative, &err); + if (err) return false; + pairs->push_back(std::make_pair(i1, i2)); + } + return true; +} + +template +bool WriteIntPairs(const string &filename, + const std::vector> &pairs) { + std::ostream *strm = &std::cout; + if (!filename.empty()) { + strm = new std::ofstream(filename); + if (!*strm) { + LOG(ERROR) << "WriteIntPairs: Can't open file: " << filename; + return false; + } + } + for (ssize_t n = 0; n < pairs.size(); ++n) { + *strm << pairs[n].first << "\t" << pairs[n].second << "\n"; + } + if (!*strm) { + LOG(ERROR) << "WriteIntPairs: Write failed: " + << (filename.empty() ? "standard output" : filename); + return false; + } + if (strm != &std::cout) delete strm; + return true; +} + +// Utilities for reading/writing label pairs. + +template +bool ReadLabelPairs(const string &filename, + std::vector> *pairs, + bool allow_negative = false) { + return ReadIntPairs(filename, pairs, allow_negative); +} + +template +bool WriteLabelPairs(const string &filename, + const std::vector> &pairs) { + return WriteIntPairs(filename, pairs); +} + +// Utilities for converting a type name to a legal C symbol. + +void ConvertToLegalCSymbol(string *s); + +// Utilities for stream I/O. + +bool AlignInput(std::istream &strm); +bool AlignOutput(std::ostream &strm); + +// An associative container for which testing membership is faster than an STL +// set if members are restricted to an interval that excludes most non-members. +// A Key must have ==, !=, and < operators defined. Element NoKey should be a +// key that marks an uninitialized key and is otherwise unused. Find() returns +// an STL const_iterator to the match found, otherwise it equals End(). +template +class CompactSet { + public: + using const_iterator = typename std::set::const_iterator; + + CompactSet() : min_key_(NoKey), max_key_(NoKey) {} + + CompactSet(const CompactSet &compact_set) + : set_(compact_set.set_), + min_key_(compact_set.min_key_), + max_key_(compact_set.max_key_) {} + + void Insert(Key key) { + set_.insert(key); + if (min_key_ == NoKey || key < min_key_) min_key_ = key; + if (max_key_ == NoKey || max_key_ < key) max_key_ = key; + } + + void Erase(Key key) { + set_.erase(key); + if (set_.empty()) { + min_key_ = max_key_ = NoKey; + } else if (key == min_key_) { + ++min_key_; + } else if (key == max_key_) { + --max_key_; + } + } + + void Clear() { + set_.clear(); + min_key_ = max_key_ = NoKey; + } + + const_iterator Find(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return set_.end(); + } else { + return set_.find(key); + } + } + + bool Member(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return false; // out of range + } else if (min_key_ != NoKey && max_key_ + 1 == min_key_ + set_.size()) { + return true; // dense range + } else { + return set_.count(key); + } + } + + const_iterator Begin() const { return set_.begin(); } + + const_iterator End() const { return set_.end(); } + + // All stored keys are greater than or equal to this value. + Key LowerBound() const { return min_key_; } + + // All stored keys are less than or equal to this value. + Key UpperBound() const { return max_key_; } + + private: + std::set set_; + Key min_key_; + Key max_key_; + + void operator=(const CompactSet &) = delete; +}; + +} // namespace fst + +#endif // FST_UTIL_H_ diff --git a/projects/llm_framework/include/fst/vector-fst.h b/projects/llm_framework/include/fst/vector-fst.h new file mode 100644 index 00000000..7514bc55 --- /dev/null +++ b/projects/llm_framework/include/fst/vector-fst.h @@ -0,0 +1,796 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Simple concrete, mutable FST whose states and arcs are stored in STL vectors. + +#ifndef FST_VECTOR_FST_H_ +#define FST_VECTOR_FST_H_ + +#include +#include +#include + +#include + +#include // For optional argument declarations +#include +#include + + +namespace fst { + +template +class VectorFst; + +template +void Cast(const F &, G *); + +// Arcs (of type A) implemented by an STL vector per state. M specifies Arc +// allocator (default declared in fst-decl.h). +template */> +class VectorState { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ArcAllocator = M; + using StateAllocator = + typename ArcAllocator::template rebind>::other; + + // Provide STL allocator for arcs. + explicit VectorState(const ArcAllocator &alloc) + : final_(Weight::Zero()), niepsilons_(0), noepsilons_(0), arcs_(alloc) {} + + VectorState(const VectorState &state, const ArcAllocator &alloc) + : final_(state.Final()), + niepsilons_(state.NumInputEpsilons()), + noepsilons_(state.NumOutputEpsilons()), + arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {} + + void Reset() { + final_ = Weight::Zero(); + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return arcs_.size(); } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + Arc *MutableArcs() { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + void ReserveArcs(size_t n) { arcs_.reserve(n); } + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void SetNumInputEpsilons(size_t n) { niepsilons_ = n; } + + void SetNumOutputEpsilons(size_t n) { noepsilons_ = n; } + + void AddArc(const Arc &arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(arc); + } + + void AddArc(Arc &&arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(std::move(arc)); + } + + template + void EmplaceArc(T&&... ctor_args) { + arcs_.emplace_back(std::forward(ctor_args)...); + IncrementNumEpsilons(arcs_.back()); + } + + void SetArc(const Arc &arc, size_t n) { + if (arcs_[n].ilabel == 0) --niepsilons_; + if (arcs_[n].olabel == 0) --noepsilons_; + IncrementNumEpsilons(arc); + arcs_[n] = arc; + } + + void DeleteArcs() { + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + void DeleteArcs(size_t n) { + for (size_t i = 0; i < n; ++i) { + if (arcs_.back().ilabel == 0) --niepsilons_; + if (arcs_.back().olabel == 0) --noepsilons_; + arcs_.pop_back(); + } + } + + // For state class allocation. + void *operator new(size_t size, StateAllocator *alloc) { + return alloc->allocate(1); + } + + // For state destruction and memory freeing. + static void Destroy(VectorState *state, StateAllocator *alloc) { + if (state) { + state->~VectorState(); + alloc->deallocate(state, 1); + } + } + + private: + // Update the number of epsilons as a result of having added an arc. + void IncrementNumEpsilons(const Arc &arc) { + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + } + + Weight final_; // Final weight. + size_t niepsilons_; // # of input epsilons + size_t noepsilons_; // # of output epsilons + std::vector arcs_; // Arc container. +}; + +namespace internal { + +// States are implemented by STL vectors, templated on the +// State definition. This does not manage the Fst properties. +template +class VectorFstBaseImpl : public FstImpl { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + VectorFstBaseImpl() : start_(kNoStateId) {} + + ~VectorFstBaseImpl() override { + for (size_t s = 0; s < states_.size(); ++s) { + State::Destroy(states_[s], &state_alloc_); + } + } + + // Copying is not permitted. + VectorFstBaseImpl(const VectorFstBaseImpl &) = delete; + VectorFstBaseImpl &operator=(const VectorFstBaseImpl &) = delete; + + // Moving is permitted. + VectorFstBaseImpl(VectorFstBaseImpl &&impl) noexcept + : FstImpl(), + states_(std::move(impl.states_)), + start_(impl.start_) { + impl.states_.clear(); + impl.start_ = kNoStateId; + } + + VectorFstBaseImpl &operator=(VectorFstBaseImpl &&impl) noexcept { + states_ = std::move(impl.states_); + start_ = impl.start_; + impl.states_.clear(); + impl.start_ = kNoStateId; + return *this; + } + + StateId Start() const { return start_; } + + Weight Final(StateId state) const { return states_[state]->Final(); } + + StateId NumStates() const { return states_.size(); } + + size_t NumArcs(StateId state) const { return states_[state]->NumArcs(); } + + size_t NumInputEpsilons(StateId state) const { + return GetState(state)->NumInputEpsilons(); + } + + size_t NumOutputEpsilons(StateId state) const { + return GetState(state)->NumOutputEpsilons(); + } + + void SetStart(StateId state) { start_ = state; } + + void SetFinal(StateId state, Weight weight) { + states_[state]->SetFinal(std::move(weight)); + } + + StateId AddState() { + states_.push_back(new (&state_alloc_) State(arc_alloc_)); + return states_.size() - 1; + } + + StateId AddState(State *state) { + states_.push_back(state); + return states_.size() - 1; + } + + void AddArc(StateId state, const Arc &arc) { states_[state]->AddArc(arc); } + + void AddArc(StateId state, Arc &&arc) { + states_[state]->AddArc(std::move(arc)); + } + + template + void EmplaceArc(StateId state, T&&... ctor_args) { + states_[state]->EmplaceArc(std::forward(ctor_args)...); + } + + void DeleteStates(const std::vector &dstates) { + std::vector newid(states_.size(), 0); + for (size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId state = 0; state < states_.size(); ++state) { + if (newid[state] != kNoStateId) { + newid[state] = nstates; + if (state != nstates) states_[nstates] = states_[state]; + ++nstates; + } else { + State::Destroy(states_[state], &state_alloc_); + } + } + states_.resize(nstates); + for (StateId state = 0; state < states_.size(); ++state) { + auto *arcs = states_[state]->MutableArcs(); + size_t narcs = 0; + auto nieps = states_[state]->NumInputEpsilons(); + auto noeps = states_[state]->NumOutputEpsilons(); + for (size_t i = 0; i < states_[state]->NumArcs(); ++i) { + const auto t = newid[arcs[i].nextstate]; + if (t != kNoStateId) { + arcs[i].nextstate = t; + if (i != narcs) arcs[narcs] = arcs[i]; + ++narcs; + } else { + if (arcs[i].ilabel == 0) --nieps; + if (arcs[i].olabel == 0) --noeps; + } + } + states_[state]->DeleteArcs(states_[state]->NumArcs() - narcs); + states_[state]->SetNumInputEpsilons(nieps); + states_[state]->SetNumOutputEpsilons(noeps); + } + if (Start() != kNoStateId) SetStart(newid[Start()]); + } + + void DeleteStates() { + for (size_t state = 0; state < states_.size(); ++state) { + State::Destroy(states_[state], &state_alloc_); + } + states_.clear(); + SetStart(kNoStateId); + } + + void DeleteArcs(StateId state, size_t n) { states_[state]->DeleteArcs(n); } + + void DeleteArcs(StateId state) { states_[state]->DeleteArcs(); } + + State *GetState(StateId state) { return states_[state]; } + + const State *GetState(StateId state) const { return states_[state]; } + + void SetState(StateId state, State *vstate) { states_[state] = vstate; } + + void ReserveStates(StateId n) { states_.reserve(n); } + + void ReserveArcs(StateId state, size_t n) { states_[state]->ReserveArcs(n); } + + // Provide information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = states_.size(); + } + + // Provide information needed for generic arc iterator. + void InitArcIterator(StateId state, ArcIteratorData *data) const { + data->base = nullptr; + data->narcs = states_[state]->NumArcs(); + data->arcs = states_[state]->Arcs(); + data->ref_count = nullptr; + } + + private: + std::vector states_; // States represenation. + StateId start_; // Initial state. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// This is a VectorFstBaseImpl container that holds VectorStates and manages FST +// properties. +template +class VectorFstImpl : public VectorFstBaseImpl { + public: + using State = S; + using Arc = typename State::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + + using VectorFstBaseImpl::Start; + using VectorFstBaseImpl::NumStates; + using VectorFstBaseImpl::GetState; + using VectorFstBaseImpl::ReserveArcs; + + friend class MutableArcIterator>; + + using BaseImpl = VectorFstBaseImpl; + + VectorFstImpl() { + SetType("vector"); + SetProperties(kNullProperties | kStaticProperties); + } + + explicit VectorFstImpl(const Fst &fst); + + static VectorFstImpl *Read(std::istream &strm, const FstReadOptions &opts); + + void SetStart(StateId state) { + BaseImpl::SetStart(state); + SetProperties(SetStartProperties(Properties())); + } + + void SetFinal(StateId state, Weight weight) { + const auto old_weight = BaseImpl::Final(state); + const auto properties = + SetFinalProperties(Properties(), old_weight, weight); + BaseImpl::SetFinal(state, std::move(weight)); + SetProperties(properties); + } + + StateId AddState() { + const auto state = BaseImpl::AddState(); + SetProperties(AddStateProperties(Properties())); + return state; + } + + void AddArc(StateId state, const Arc &arc) { + BaseImpl::AddArc(state, arc); + UpdatePropertiesAfterAddArc(state); + } + + void AddArc(StateId state, Arc &&arc) { + BaseImpl::AddArc(state, std::move(arc)); + UpdatePropertiesAfterAddArc(state); + } + + template + void EmplaceArc(StateId state, T&&... ctor_args) { + BaseImpl::EmplaceArc(state, std::forward(ctor_args)...); + UpdatePropertiesAfterAddArc(state); + } + + void DeleteStates(const std::vector &dstates) { + BaseImpl::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + } + + void DeleteStates() { + BaseImpl::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), kStaticProperties)); + } + + void DeleteArcs(StateId state, size_t n) { + BaseImpl::DeleteArcs(state, n); + SetProperties(DeleteArcsProperties(Properties())); + } + + void DeleteArcs(StateId state) { + BaseImpl::DeleteArcs(state); + SetProperties(DeleteArcsProperties(Properties())); + } + + // Properties always true of this FST class + static constexpr uint64 kStaticProperties = kExpanded | kMutable; + + private: + void UpdatePropertiesAfterAddArc(StateId state) { + auto *vstate = GetState(state); + const size_t num_arcs{vstate->NumArcs()}; + if (num_arcs) { + const auto &arc = vstate->GetArc(num_arcs - 1); + const auto *parc = (num_arcs < 2) + ? nullptr + : &(vstate->GetArc(num_arcs - 2)); + SetProperties(AddArcProperties(Properties(), state, arc, parc)); + } + } + + // Minimum file format version supported. + static constexpr int kMinFileVersion = 2; +}; + +template +constexpr uint64 VectorFstImpl::kStaticProperties; + +template +constexpr int VectorFstImpl::kMinFileVersion; + +template +VectorFstImpl::VectorFstImpl(const Fst &fst) { + SetType("vector"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + BaseImpl::SetStart(fst.Start()); + if (fst.Properties(kExpanded, false)) { + BaseImpl::ReserveStates(CountStates(fst)); + } + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + BaseImpl::AddState(); + BaseImpl::SetFinal(state, fst.Final(state)); + ReserveArcs(state, fst.NumArcs(state)); + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + BaseImpl::AddArc(state, arc); + } + } + SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); +} + +template +VectorFstImpl *VectorFstImpl::Read(std::istream &strm, + const FstReadOptions &opts) { + std::unique_ptr> impl(new VectorFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->BaseImpl::SetStart(hdr.Start()); + if (hdr.NumStates() != kNoStateId) impl->ReserveStates(hdr.NumStates()); + StateId state = 0; + for (; hdr.NumStates() == kNoStateId || state < hdr.NumStates(); ++state) { + Weight weight; + if (!weight.Read(strm)) break; + impl->BaseImpl::AddState(); + auto *vstate = impl->GetState(state); + vstate->SetFinal(weight); + int64 narcs; + ReadType(strm, &narcs); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->ReserveArcs(state, narcs); + for (int64 i = 0; i < narcs; ++i) { + Arc arc; + ReadType(strm, &arc.ilabel); + ReadType(strm, &arc.olabel); + arc.weight.Read(strm); + ReadType(strm, &arc.nextstate); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->BaseImpl::AddArc(state, std::move(arc)); + } + } + if (hdr.NumStates() != kNoStateId && state != hdr.NumStates()) { + LOG(ERROR) << "VectorFst::Read: Unexpected end of file: " << opts.source; + return nullptr; + } + return impl.release(); +} + +} // namespace internal + +// Simple concrete, mutable FST. This class attaches interface to implementation +// and handles reference counting, delegating most methods to ImplToMutableFst. +// Also supports ReserveStates and ReserveArcs methods (cf. STL vector methods). +// The second optional template argument gives the State definition. +template */> +class VectorFst : public ImplToMutableFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using State = S; + using Impl = internal::VectorFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + friend class MutableArcIterator>; + + template + friend void Cast(const F &, G *); + + VectorFst() : ImplToMutableFst(std::make_shared()) {} + + explicit VectorFst(const Fst &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + VectorFst(const VectorFst &fst, bool safe = false) + : ImplToMutableFst(fst) {} + + VectorFst(VectorFst &&) noexcept; + + // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. + VectorFst *Copy(bool safe = false) const override { + return new VectorFst(*this, safe); + } + + VectorFst &operator=(const VectorFst &) = default; + + VectorFst &operator=(VectorFst &&) noexcept; + + VectorFst &operator=(const Fst &fst) override { + if (this != &fst) SetImpl(std::make_shared(fst)); + return *this; + } + + template + void EmplaceArc(StateId state, T&&... ctor_args) { + MutateCheck(); + GetMutableImpl()->EmplaceArc(state, std::forward(ctor_args)...); + } + + // Reads a VectorFst from an input stream, returning nullptr on error. + static VectorFst *Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new VectorFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a VectorFst from a file, returning nullptr on error; empty filename + // reads from standard input. + static VectorFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst>::Read(filename); + return impl ? new VectorFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return WriteFst(*this, strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + inline void InitMutableArcIterator(StateId s, + MutableArcIteratorData *) override; + + using ImplToMutableFst>::ReserveArcs; + using ImplToMutableFst>::ReserveStates; + + private: + using ImplToMutableFst>::GetImpl; + using ImplToMutableFst>::GetMutableImpl; + using ImplToMutableFst>::MutateCheck; + using ImplToMutableFst>::SetImpl; + + explicit VectorFst(std::shared_ptr impl) + : ImplToMutableFst(impl) {} +}; + +template +inline VectorFst::VectorFst( + VectorFst &&fst) noexcept = default; + +template +inline VectorFst &VectorFst::operator=( + VectorFst &&fst) noexcept = default; + +// Writes FST to file in Vector format, potentially with a pass over the machine +// before writing to compute number of states. +template +template +bool VectorFst::WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts) { + static constexpr int file_version = 2; + bool update_header = true; + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(kNoStateId); + std::streampos start_offset = 0; + if (fst.Properties(kExpanded, false) || opts.stream_write || + (start_offset = strm.tellp()) != -1) { + hdr.SetNumStates(CountStates(fst)); + update_header = false; + } + const auto properties = + fst.Properties(kCopyProperties, false) | Impl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, + "vector", properties, &hdr); + StateId num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + fst.Final(s).Write(strm); + const int64 narcs = fst.NumArcs(s); + WriteType(strm, narcs); + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + WriteType(strm, arc.ilabel); + WriteType(strm, arc.olabel); + arc.weight.Write(strm); + WriteType(strm, arc.nextstate); + } + ++num_states; + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "VectorFst::Write: Write failed: " << opts.source; + return false; + } + if (update_header) { + hdr.SetNumStates(num_states); + return internal::FstImpl::UpdateFstHeader( + fst, strm, opts, file_version, "vector", properties, &hdr, + start_offset); + } else { + if (num_states != hdr.NumStates()) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + } + return true; +} + +// Specialization for VectorFst; see generic version in fst.h for sample usage +// (but use the VectorFst type instead). This version should inline. +template +class StateIterator> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const VectorFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + const StateId nstates_; + StateId s_; +}; + +// Specialization for VectorFst; see generic version in fst.h for sample usage +// (but use the VectorFst type instead). This version should inline. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const VectorFst &fst, StateId s) + : arcs_(fst.GetImpl()->GetState(s)->Arcs()), + narcs_(fst.GetImpl()->GetState(s)->NumArcs()), + i_(0) {} + + bool Done() const { return i_ >= narcs_; } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + size_t Position() const { return i_; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32, uint32) {} + + private: + const Arc *arcs_; + size_t narcs_; + size_t i_; +}; + +// Specialization for VectorFst; see generic version in mutable-fst.h for sample +// usage (but use the VectorFst type instead). This version should inline. +template +class MutableArcIterator> + : public MutableArcIteratorBase { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MutableArcIterator(VectorFst *fst, StateId s) : i_(0) { + fst->MutateCheck(); + state_ = fst->GetMutableImpl()->GetState(s); + properties_ = &fst->GetImpl()->properties_; + } + + bool Done() const final { return i_ >= state_->NumArcs(); } + + const Arc &Value() const final { return state_->GetArc(i_); } + + void Next() final { ++i_; } + + size_t Position() const final { return i_; } + + void Reset() final { i_ = 0; } + + void Seek(size_t a) final { i_ = a; } + + void SetValue(const Arc &arc) final { + const auto &oarc = state_->GetArc(i_); + if (oarc.ilabel != oarc.olabel) *properties_ &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + *properties_ &= ~kIEpsilons; + if (oarc.olabel == 0) *properties_ &= ~kEpsilons; + } + if (oarc.olabel == 0) *properties_ &= ~kOEpsilons; + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) { + *properties_ &= ~kWeighted; + } + state_->SetArc(arc, i_); + if (arc.ilabel != arc.olabel) { + *properties_ |= kNotAcceptor; + *properties_ &= ~kAcceptor; + } + if (arc.ilabel == 0) { + *properties_ |= kIEpsilons; + *properties_ &= ~kNoIEpsilons; + if (arc.olabel == 0) { + *properties_ |= kEpsilons; + *properties_ &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + *properties_ |= kOEpsilons; + *properties_ &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + *properties_ |= kWeighted; + *properties_ &= ~kUnweighted; + } + *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | + kNoOEpsilons | kWeighted | kUnweighted; + } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + State *state_; + uint64 *properties_; + size_t i_; +}; + +// Provides information needed for the generic mutable arc iterator. +template +inline void VectorFst::InitMutableArcIterator( + StateId s, MutableArcIteratorData *data) { + data->base = new MutableArcIterator>(this, s); +} + +// A useful alias when using StdArc. +using StdVectorFst = VectorFst; + +} // namespace fst + +#endif // FST_VECTOR_FST_H_ diff --git a/projects/llm_framework/include/fst/verify.h b/projects/llm_framework/include/fst/verify.h new file mode 100644 index 00000000..2ea8a64a --- /dev/null +++ b/projects/llm_framework/include/fst/verify.h @@ -0,0 +1,100 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to verify an FST's contents. + +#ifndef FST_VERIFY_H_ +#define FST_VERIFY_H_ + +#include + +#include +#include + + +namespace fst { + +// Verifies that an Fst's contents are sane. +template +bool Verify(const Fst &fst, bool allow_negative_labels = false) { + using StateId = typename Arc::StateId; + const auto start = fst.Start(); + const auto *isyms = fst.InputSymbols(); + const auto *osyms = fst.OutputSymbols(); + // Count states + StateId ns = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) ++ns; + if (start == kNoStateId && ns > 0) { + LOG(ERROR) << "Verify: FST start state ID not set"; + return false; + } else if (start >= ns) { + LOG(ERROR) << "Verify: FST start state ID exceeds number of states"; + return false; + } + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + auto state = siter.Value(); + size_t na = 0; + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (!allow_negative_labels && arc.ilabel < 0) { + LOG(ERROR) << "Verify: FST input label ID of arc at position " << na + << " of state " << state << " is negative"; + return false; + } else if (isyms && isyms->Find(arc.ilabel).empty()) { + LOG(ERROR) << "Verify: FST input label ID " << arc.ilabel + << " of arc at position " << na << " of state " << state + << " is missing from input symbol table \"" << isyms->Name() + << "\""; + return false; + } else if (!allow_negative_labels && arc.olabel < 0) { + LOG(ERROR) << "Verify: FST output label ID of arc at position " << na + << " of state " << state << " is negative"; + return false; + } else if (osyms && osyms->Find(arc.olabel).empty()) { + LOG(ERROR) << "Verify: FST output label ID " << arc.olabel + << " of arc at position " << na << " of state " << state + << " is missing from output symbol table \"" << osyms->Name() + << "\""; + return false; + } else if (!arc.weight.Member()) { + LOG(ERROR) << "Verify: FST weight of arc at position " << na + << " of state " << state << " is invalid"; + return false; + } else if (arc.nextstate < 0) { + LOG(ERROR) << "Verify: FST destination state ID of arc at position " + << na << " of state " << state << " is negative"; + return false; + } else if (arc.nextstate >= ns) { + LOG(ERROR) << "Verify: FST destination state ID of arc at position " + << na << " of state " << state + << " exceeds number of states"; + return false; + } + ++na; + } + if (!fst.Final(state).Member()) { + LOG(ERROR) << "Verify: FST final weight of state " << state + << " is invalid"; + return false; + } + } + const auto fst_props = fst.Properties(kFstProperties, false); + if (fst_props & kError) { + LOG(ERROR) << "Verify: FST error property is set"; + return false; + } + uint64 known_props; + uint64 test_props = + ComputeProperties(fst, kFstProperties, &known_props, false); + if (!CompatProperties(fst_props, test_props)) { + LOG(ERROR) << "Verify: Stored FST properties incorrect " + << "(props1 = stored props, props2 = tested)"; + return false; + } else { + return true; + } +} + +} // namespace fst + +#endif // FST_VERIFY_H_ diff --git a/projects/llm_framework/include/fst/visit.h b/projects/llm_framework/include/fst/visit.h new file mode 100644 index 00000000..c6658047 --- /dev/null +++ b/projects/llm_framework/include/fst/visit.h @@ -0,0 +1,321 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Queue-dependent visitation of finite-state transducers. See also dfs-visit.h. + +#ifndef FST_VISIT_H_ +#define FST_VISIT_H_ + + +#include +#include + + +namespace fst { + +// Visitor Interface: class determining actions taken during a visit. If any of +// the boolean member functions return false, the visit is aborted by first +// calling FinishState() on all unfinished (grey) states and then calling +// FinishVisit(). +// +// Note this is more general than the visitor interface in dfs-visit.h but lacks +// some DFS-specific behavior. +// +// template +// class Visitor { +// public: +// using StateId = typename Arc::StateId; +// +// Visitor(T *return_data); +// +// // Invoked before visit. +// void InitVisit(const Fst &fst); +// +// // Invoked when state discovered (2nd arg is visitation root). +// bool InitState(StateId s, StateId root); +// +// // Invoked when arc to white/undiscovered state examined. +// bool WhiteArc(StateId s, const Arc &arc); +// +// // Invoked when arc to grey/unfinished state examined. +// bool GreyArc(StateId s, const Arc &arc); +// +// // Invoked when arc to black/finished state examined. +// bool BlackArc(StateId s, const Arc &arc); +// +// // Invoked when state finished. +// void FinishState(StateId s); +// +// // Invoked after visit. +// void FinishVisit(); +// }; + +// Performs queue-dependent visitation. Visitor class argument determines +// actions and contains any return data. ArcFilter determines arcs that are +// considered. If 'access_only' is true, performs visitation only to states +// accessible from the initial state. +template +void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, + bool access_only = false) { + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + visitor->InitVisit(fst); + const auto start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + // An FST's state's visit color. + static constexpr uint8 kWhiteState = 0x01; // Undiscovered. + static constexpr uint8 kGreyState = 0x02; // Discovered & unfinished. + static constexpr uint8 kBlackState = 0x04; // Finished. + // We destroy an iterator as soon as possible and mark it so. + static constexpr uint8 kArcIterDone = 0x08; + std::vector state_status; + std::vector *> arc_iterator; + MemoryPool> aiter_pool; + StateId nstates = start + 1; // Number of known states in general case. + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // Tests if expanded, then uses + nstates = CountStates(fst); // ExpandedFst::NumStates(). + expanded = true; + } + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + StateIterator> siter(fst); + // Continues visit while true. + bool visit = true; + // Iterates over trees in visit forest. + for (auto root = start; visit && root < nstates;) { + visit = visitor->InitState(root, root); + state_status[root] = kGreyState; + queue->Enqueue(root); + while (!queue->Empty()) { + auto state = queue->Head(); + if (state >= state_status.size()) { + nstates = state + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Creates arc iterator if needed. + if (!arc_iterator[state] && !(state_status[state] & kArcIterDone) && + visit) { + arc_iterator[state] = new (&aiter_pool) ArcIterator(fst, state); + } + // Deletes arc iterator if done. + auto *aiter = arc_iterator[state]; + if ((aiter && aiter->Done()) || !visit) { + Destroy(aiter, &aiter_pool); + arc_iterator[state] = nullptr; + state_status[state] |= kArcIterDone; + } + // Dequeues state and marks black if done. + if (state_status[state] & kArcIterDone) { + queue->Dequeue(); + visitor->FinishState(state); + state_status[state] = kBlackState; + continue; + } + const auto &arc = aiter->Value(); + if (arc.nextstate >= state_status.size()) { + nstates = arc.nextstate + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Visits respective arc types. + if (filter(arc)) { + // Enqueues destination state and marks grey if white. + if (state_status[arc.nextstate] == kWhiteState) { + visit = visitor->WhiteArc(state, arc); + if (!visit) continue; + visit = visitor->InitState(arc.nextstate, root); + state_status[arc.nextstate] = kGreyState; + queue->Enqueue(arc.nextstate); + } else if (state_status[arc.nextstate] == kBlackState) { + visit = visitor->BlackArc(state, arc); + } else { + visit = visitor->GreyArc(state, arc); + } + } + aiter->Next(); + // Destroys an iterator ASAP for efficiency. + if (aiter->Done()) { + Destroy(aiter, &aiter_pool); + arc_iterator[state] = nullptr; + state_status[state] |= kArcIterDone; + } + } + if (access_only) break; + // Finds next tree root. + for (root = (root == start) ? 0 : root + 1; + root < nstates && state_status[root] != kWhiteState; ++root) { + } + // Check for a state beyond the largest known state. + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_status.push_back(kWhiteState); + arc_iterator.push_back(nullptr); + break; + } + } + } + } + visitor->FinishVisit(); +} + +template +inline void Visit(const Fst &fst, Visitor *visitor, Queue *queue) { + Visit(fst, visitor, queue, AnyArcFilter()); +} + +// Copies input FST to mutable FST following queue order. +template +class CopyVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + explicit CopyVisitor(MutableFst *ofst) : ifst_(nullptr), ofst_(ofst) {} + + void InitVisit(const Fst &ifst) { + ifst_ = &ifst; + ofst_->DeleteStates(); + ofst_->SetStart(ifst_->Start()); + } + + bool InitState(StateId state, StateId) { + while (ofst_->NumStates() <= state) ofst_->AddState(); + return true; + } + + bool WhiteArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + bool GreyArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + bool BlackArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + void FinishState(StateId state) { + ofst_->SetFinal(state, ifst_->Final(state)); + } + + void FinishVisit() {} + + private: + const Fst *ifst_; + MutableFst *ofst_; +}; + +// Visits input FST up to a state limit following queue order. +template +class PartialVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + explicit PartialVisitor(StateId maxvisit) + : fst_(nullptr), maxvisit_(maxvisit) {} + + void InitVisit(const Fst &ifst) { + fst_ = &ifst; + ninit_ = 0; + nfinish_ = 0; + } + + bool InitState(StateId state, StateId root) { + ++ninit_; + return ninit_ <= maxvisit_; + } + + bool WhiteArc(StateId state, const Arc &arc) { return true; } + + bool GreyArc(StateId state, const Arc &arc) { return true; } + + bool BlackArc(StateId state, const Arc &arc) { return true; } + + void FinishState(StateId state) { + fst_->Final(state); // Visits super-final arc. + ++nfinish_; + } + + void FinishVisit() {} + + StateId NumInitialized() { return ninit_; } + + StateId NumFinished() { return nfinish_; } + + private: + const Fst *fst_; + StateId maxvisit_; + StateId ninit_; + StateId nfinish_; +}; + +// Copies input FST to mutable FST up to a state limit following queue order. +template +class PartialCopyVisitor : public CopyVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using CopyVisitor::WhiteArc; + + PartialCopyVisitor(MutableFst *ofst, StateId maxvisit, + bool copy_grey = true, bool copy_black = true) + : CopyVisitor(ofst), maxvisit_(maxvisit), + copy_grey_(copy_grey), copy_black_(copy_black) {} + + void InitVisit(const Fst &ifst) { + CopyVisitor::InitVisit(ifst); + ninit_ = 0; + nfinish_ = 0; + } + + bool InitState(StateId state, StateId root) { + CopyVisitor::InitState(state, root); + ++ninit_; + return ninit_ <= maxvisit_; + } + + bool GreyArc(StateId state, const Arc &arc) { + if (copy_grey_) return CopyVisitor::GreyArc(state, arc); + return true; + } + + bool BlackArc(StateId state, const Arc &arc) { + if (copy_black_) return CopyVisitor::BlackArc(state, arc); + return true; + } + + void FinishState(StateId state) { + CopyVisitor::FinishState(state); + ++nfinish_; + } + + void FinishVisit() {} + + StateId NumInitialized() { return ninit_; } + + StateId NumFinished() { return nfinish_; } + + private: + StateId maxvisit_; + StateId ninit_; + StateId nfinish_; + const bool copy_grey_; + const bool copy_black_; +}; + +} // namespace fst + +#endif // FST_VISIT_H_ diff --git a/projects/llm_framework/include/fst/weight.h b/projects/llm_framework/include/fst/weight.h new file mode 100644 index 00000000..ea012ec6 --- /dev/null +++ b/projects/llm_framework/include/fst/weight.h @@ -0,0 +1,389 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// General weight set and associated semiring operation definitions. + +#ifndef FST_WEIGHT_H_ +#define FST_WEIGHT_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +// A semiring is specified by two binary operations Plus and Times and two +// designated elements Zero and One with the following properties: +// +// Plus: associative, commutative, and has Zero as its identity. +// +// Times: associative and has identity One, distributes w.r.t. Plus, and +// has Zero as an annihilator: +// Times(Zero(), a) == Times(a, Zero()) = Zero(). +// +// A left semiring distributes on the left; a right semiring is similarly +// defined. +// +// A Weight class must have binary functions Plus and Times and static member +// functions Zero() and One() and these must form (at least) a left or right +// semiring. +// +// In addition, the following should be defined for a Weight: +// +// Member: predicate on set membership. +// +// NoWeight: static member function that returns an element that is +// not a set member; used to signal an error. +// +// >>: reads textual representation of a weight. +// +// <<: prints textual representation of a weight. +// +// Read(istream &istrm): reads binary representation of a weight. +// +// Write(ostream &ostrm): writes binary representation of a weight. +// +// Hash: maps weight to size_t. +// +// ApproxEqual: approximate equality (for inexact weights) +// +// Quantize: quantizes w.r.t delta (for inexact weights) +// +// Divide: +// - In a left semiring, for all a, b, b', c: +// if Times(a, b) = c, Divide(c, a, DIVIDE_LEFT) = b' and b'.Member(), +// then Times(a, b') = c. +// - In a right semiring, for all a, a', b, c: +// if Times(a, b) = c, Divide(c, b, DIVIDE_RIGHT) = a' and a'.Member(), +// then Times(a', b) = c. +// - In a commutative semiring, +// * for all a, c: +// Divide(c, a, DIVIDE_ANY) = Divide(c, a, DIVIDE_LEFT) +// = Divide(c, a, DIVIDE_RIGHT) +// * for all a, b, b', c: +// if Times(a, b), Divide(c, a, DIVIDE_ANY) = b' and b'.Member(), +// then Times(a, b') = c +// - In the case where there exist no b such that c = Times(a, b), the +// return value of Divide(c, a, DIVIDE_LEFT) is unspecified. Returning +// Weight::NoWeight() is recommemded but not required in order to +// allow the most efficient implementation. +// - All algorithms in this library only call Divide(c, a) when it is +// guaranteed that there exists a b such that c = Times(a, b). +// +// ReverseWeight: the type of the corresponding reverse weight. +// +// Typically the same type as Weight for a (both left and right) semiring. +// For the left string semiring, it is the right string semiring. +// +// Reverse: a mapping from Weight to ReverseWeight s.t. +// +// --> Reverse(Reverse(a)) = a +// --> Reverse(Plus(a, b)) = Plus(Reverse(a), Reverse(b)) +// --> Reverse(Times(a, b)) = Times(Reverse(b), Reverse(a)) +// Typically the identity mapping in a (both left and right) semiring. +// In the left string semiring, it maps to the reverse string in the right +// string semiring. +// +// Properties: specifies additional properties that hold: +// LeftSemiring: indicates weights form a left semiring. +// RightSemiring: indicates weights form a right semiring. +// Commutative: for all a,b: Times(a,b) == Times(b,a) +// Idempotent: for all a: Plus(a, a) == a. +// Path: for all a, b: Plus(a, b) == a or Plus(a, b) == b. + +// CONSTANT DEFINITIONS + +// A representable float near .001. +constexpr float kDelta = 1.0F / 1024.0F; + +// For all a, b, c: Times(c, Plus(a, b)) = Plus(Times(c, a), Times(c, b)). +constexpr uint64 kLeftSemiring = 0x0000000000000001ULL; + +// For all a, b, c: Times(Plus(a, b), c) = Plus(Times(a, c), Times(b, c)). +constexpr uint64 kRightSemiring = 0x0000000000000002ULL; + +constexpr uint64 kSemiring = kLeftSemiring | kRightSemiring; + +// For all a, b: Times(a, b) = Times(b, a). +constexpr uint64 kCommutative = 0x0000000000000004ULL; + +// For all a: Plus(a, a) = a. +constexpr uint64 kIdempotent = 0x0000000000000008ULL; + +// For all a, b: Plus(a, b) = a or Plus(a, b) = b. +constexpr uint64 kPath = 0x0000000000000010ULL; + +// For random weight generation: default number of distinct weights. +// This is also used for a few other weight generation defaults. +constexpr size_t kNumRandomWeights = 5; + +// Weight property boolean constants needed for SFINAE. + +// MSVC compiler bug workaround: an expression containing W::Properties() cannot +// be directly used as a value argument to std::enable_if or integral_constant. +// WeightPropertiesThunk::Properties works instead, however. +namespace bug { +template +struct WeightPropertiesThunk { + WeightPropertiesThunk() = delete; + constexpr static const uint64 Properties = W::Properties(); +}; + +template +using TestWeightProperties = std::integral_constant::Properties & props) == props>; +} // namespace bug + +template +using IsIdempotent = bug::TestWeightProperties; + +template +using IsPath = bug::TestWeightProperties; + + +// Determines direction of division. +enum DivideType { + DIVIDE_LEFT, // left division + DIVIDE_RIGHT, // right division + DIVIDE_ANY +}; // division in a commutative semiring + +// NATURAL ORDER +// +// By definition: +// +// a <= b iff a + b = a +// +// The natural order is a negative partial order iff the semiring is +// idempotent. It is trivially monotonic for plus. It is left +// (resp. right) monotonic for times iff the semiring is left +// (resp. right) distributive. It is a total order iff the semiring +// has the path property. +// +// For more information, see: +// +// Mohri, M. 2002. Semiring framework and algorithms for shortest-distance +// problems, Journal of Automata, Languages and +// Combinatorics 7(3): 321-350, 2002. +// +// We define the strict version of this order below. + +template +class NaturalLess { +public: + using Weight = W; + + NaturalLess() { + if (!(W::Properties() & kIdempotent)) { + FSTERROR() << "NaturalLess: Weight type is not idempotent: " << W::Type(); + } + } + + bool operator()(const W &w1, const W &w2) const { + return (Plus(w1, w2) == w1) && w1 != w2; + } +}; + +// Power is the iterated product for arbitrary semirings such that Power(w, 0) +// is One() for the semiring, and Power(w, n) = Times(Power(w, n - 1), w). +template +Weight Power(const Weight &weight, size_t n) { + auto result = Weight::One(); + for (size_t i = 0; i < n; ++i) result = Times(result, weight); + return result; +} + +// Simple default adder class. Specializations might be more complex. +template +class Adder { + public: + explicit Adder(Weight w = Weight::Zero()) : sum_(w) { } + + Weight Add(const Weight &w) { + sum_ = Plus(sum_, w); + return sum_; + } + + Weight Sum() { return sum_; } + + void Reset(Weight w = Weight::Zero()) { sum_ = w; } + + private: + Weight sum_; +}; + +// General weight converter: raises error. +template +struct WeightConvert { + W2 operator()(W1 w1) const { + FSTERROR() << "WeightConvert: Can't convert weight from \"" << W1::Type() + << "\" to \"" << W2::Type(); + return W2::NoWeight(); + } +}; + +// Specialized weight converter to self. +template +struct WeightConvert { + W operator()(W weight) const { return weight; } +}; + +// General random weight generator: raises error. +template +struct WeightGenerate { + W operator()() const { + FSTERROR() << "WeightGenerate: No random generator for " << W::Type(); + return W::NoWeight(); + } +}; + +namespace internal { + +class CompositeWeightIO { + public: + CompositeWeightIO(); + CompositeWeightIO(char separator, std::pair parentheses); + + std::pair parentheses() const { + return {open_paren_, close_paren_}; + } + char separator() const { return separator_; } + + bool error() const { return error_; } + + protected: + const char separator_; + const char open_paren_; + const char close_paren_; + + private: + bool error_; +}; + +} // namespace internal + +// Helper class for writing textual composite weights. +class CompositeWeightWriter : public internal::CompositeWeightIO { + public: + // Uses configuration from flags (FLAGS_fst_weight_separator, + // FLAGS_fst_weight_parentheses). + explicit CompositeWeightWriter(std::ostream &ostrm); + + // parentheses defines the opening and closing parenthesis characters. + // Set parentheses = {0, 0} to disable writing parenthesis. + CompositeWeightWriter(std::ostream &ostrm, char separator, + std::pair parentheses); + + CompositeWeightWriter(const CompositeWeightWriter &) = delete; + CompositeWeightWriter &operator=(const CompositeWeightWriter &) = delete; + + // Writes open parenthesis to a stream if option selected. + void WriteBegin(); + + // Writes element to a stream. + template + void WriteElement(const T &comp) { + if (i_++ > 0) ostrm_ << separator_; + ostrm_ << comp; + } + + // Writes close parenthesis to a stream if option selected. + void WriteEnd(); + + private: + std::ostream &ostrm_; + int i_ = 0; // Element position. +}; + +// Helper class for reading textual composite weights. Elements are separated by +// a separator character. There must be at least one element per textual +// representation. Parentheses characters should be set if the composite +// weights themselves contain composite weights to ensure proper parsing. +class CompositeWeightReader : public internal::CompositeWeightIO { + public: + // Uses configuration from flags (FLAGS_fst_weight_separator, + // FLAGS_fst_weight_parentheses). + explicit CompositeWeightReader(std::istream &istrm); + + // parentheses defines the opening and closing parenthesis characters. + // Set parentheses = {0, 0} to disable reading parenthesis. + CompositeWeightReader(std::istream &istrm, char separator, + std::pair parentheses); + + CompositeWeightReader(const CompositeWeightReader &) = delete; + CompositeWeightReader &operator=(const CompositeWeightReader &) = delete; + + // Reads open parenthesis from a stream if option selected. + void ReadBegin(); + + // Reads element from a stream. The second argument, when true, indicates that + // this will be the last element (allowing more forgiving formatting of the + // last element). Returns false when last element is read. + template + bool ReadElement(T *comp, bool last = false); + + // Finalizes reading. + void ReadEnd(); + + private: + std::istream &istrm_; // Input stream. + int c_ = 0; // Last character read, or EOF. + int depth_ = 0; // Weight parentheses depth. +}; + +template +inline bool CompositeWeightReader::ReadElement(T *comp, bool last) { + string s; + const bool has_parens = open_paren_ != 0; + while ((c_ != std::istream::traits_type::eof()) && !std::isspace(c_) && + (c_ != separator_ || depth_ > 1 || last) && + (c_ != close_paren_ || depth_ != 1)) { + s += c_; + // If parentheses encountered before separator, they must be matched. + if (has_parens && c_ == open_paren_) { + ++depth_; + } else if (has_parens && c_ == close_paren_) { + // Failure on unmatched parentheses. + if (depth_ == 0) { + FSTERROR() << "CompositeWeightReader: Unmatched close paren: " + << "Is the fst_weight_parentheses flag set correctly?"; + istrm_.clear(std::ios::badbit); + return false; + } + --depth_; + } + c_ = istrm_.get(); + } + if (s.empty()) { + FSTERROR() << "CompositeWeightReader: Empty element: " + << "Is the fst_weight_parentheses flag set correctly?"; + istrm_.clear(std::ios::badbit); + return false; + } + std::istringstream istrm(s); + istrm >> *comp; + // Skips separator/close parenthesis. + if (c_ != std::istream::traits_type::eof() && !std::isspace(c_)) { + c_ = istrm_.get(); + } + const bool is_eof = c_ == std::istream::traits_type::eof(); + // Clears fail bit if just EOF. + if (is_eof && !istrm_.bad()) istrm_.clear(std::ios::eofbit); + return !is_eof && !std::isspace(c_); +} + +} // namespace fst + +#endif // FST_WEIGHT_H_ diff --git a/projects/llm_framework/include/gflags/defines.h b/projects/llm_framework/include/gflags/defines.h new file mode 100644 index 00000000..944ed7db --- /dev/null +++ b/projects/llm_framework/include/gflags/defines.h @@ -0,0 +1,48 @@ +/* Generated from defines.h.in during build configuration using CMake. */ + +// Note: This header file is only used internally. It is not part of public interface! +// Any cmakedefine is defined using the -D flag instead when Bazel is used. +// For Bazel, this file is thus not used to avoid a private file in $(GENDIR). + +#ifndef GFLAGS_DEFINES_H_ +#define GFLAGS_DEFINES_H_ + + +// Define if you build this library for a MS Windows OS. +/* #undef OS_WINDOWS */ + +// Define if you have the header file. +#define HAVE_STDINT_H + +// Define if you have the header file. +#define HAVE_SYS_TYPES_H + +// Define if you have the header file. +#define HAVE_INTTYPES_H + +// Define if you have the header file. +#define HAVE_SYS_STAT_H + +// Define if you have the header file. +#define HAVE_UNISTD_H + +// Define if you have the header file. +#define HAVE_FNMATCH_H + +// Define if you have the header file (Windows 2000/XP). +/* #undef HAVE_SHLWAPI_H */ + +// Define if you have the strtoll function. +#define HAVE_STRTOLL + +// Define if you have the strtoq function. +/* #undef HAVE_STRTOQ */ + +// Define if you have the header file. +/* #undef HAVE_PTHREAD */ + +// Define if your pthread library defines the type pthread_rwlock_t +/* #undef HAVE_RWLOCK */ + + +#endif // GFLAGS_DEFINES_H_ diff --git a/projects/llm_framework/include/gflags/gflags.h b/projects/llm_framework/include/gflags/gflags.h new file mode 100644 index 00000000..9273da8d --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags.h @@ -0,0 +1,626 @@ +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// Revamped and reorganized by Craig Silverstein +// +// This is the file that should be included by any file which declares +// or defines a command line flag or wants to parse command line flags +// or print a program usage message (which will include information about +// flags). Executive summary, in the form of an example foo.cc file: +// +// #include "foo.h" // foo.h has a line "DECLARE_int32(start);" +// #include "validators.h" // hypothetical file defining ValidateIsFile() +// +// DEFINE_int32(end, 1000, "The last record to read"); +// +// DEFINE_string(filename, "my_file.txt", "The file to read"); +// // Crash if the specified file does not exist. +// static bool dummy = RegisterFlagValidator(&FLAGS_filename, +// &ValidateIsFile); +// +// DECLARE_bool(verbose); // some other file has a DEFINE_bool(verbose, ...) +// +// void MyFunc() { +// if (FLAGS_verbose) printf("Records %d-%d\n", FLAGS_start, FLAGS_end); +// } +// +// Then, at the command-line: +// ./foo --noverbose --start=5 --end=100 +// +// For more details, see +// doc/gflags.html +// +// --- A note about thread-safety: +// +// We describe many functions in this routine as being thread-hostile, +// thread-compatible, or thread-safe. Here are the meanings we use: +// +// thread-safe: it is safe for multiple threads to call this routine +// (or, when referring to a class, methods of this class) +// concurrently. +// thread-hostile: it is not safe for multiple threads to call this +// routine (or methods of this class) concurrently. In gflags, +// most thread-hostile routines are intended to be called early in, +// or even before, main() -- that is, before threads are spawned. +// thread-compatible: it is safe for multiple threads to read from +// this variable (when applied to variables), or to call const +// methods of this class (when applied to classes), as long as no +// other thread is writing to the variable or calling non-const +// methods of this class. + +#ifndef GFLAGS_GFLAGS_H_ +#define GFLAGS_GFLAGS_H_ + +#include +#include + +#include "gflags/gflags_declare.h" // IWYU pragma: export + + +// We always want to export variables defined in user code +#ifndef GFLAGS_DLL_DEFINE_FLAG +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DEFINE_FLAG __declspec(dllexport) +# else +# define GFLAGS_DLL_DEFINE_FLAG +# endif +#endif + + +namespace GFLAGS_NAMESPACE { + + +// -------------------------------------------------------------------- +// To actually define a flag in a file, use DEFINE_bool, +// DEFINE_string, etc. at the bottom of this file. You may also find +// it useful to register a validator with the flag. This ensures that +// when the flag is parsed from the commandline, or is later set via +// SetCommandLineOption, we call the validation function. It is _not_ +// called when you assign the value to the flag directly using the = operator. +// +// The validation function should return true if the flag value is valid, and +// false otherwise. If the function returns false for the new setting of the +// flag, the flag will retain its current value. If it returns false for the +// default value, ParseCommandLineFlags() will die. +// +// This function is safe to call at global construct time (as in the +// example below). +// +// Example use: +// static bool ValidatePort(const char* flagname, int32 value) { +// if (value > 0 && value < 32768) // value is ok +// return true; +// printf("Invalid value for --%s: %d\n", flagname, (int)value); +// return false; +// } +// DEFINE_int32(port, 0, "What port to listen on"); +// static bool dummy = RegisterFlagValidator(&FLAGS_port, &ValidatePort); + +// Returns true if successfully registered, false if not (because the +// first argument doesn't point to a command-line flag, or because a +// validator is already registered for this flag). +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const bool* flag, bool (*validate_fn)(const char*, bool)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const int32* flag, bool (*validate_fn)(const char*, int32)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const uint32* flag, bool (*validate_fn)(const char*, uint32)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const int64* flag, bool (*validate_fn)(const char*, int64)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const uint64* flag, bool (*validate_fn)(const char*, uint64)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const double* flag, bool (*validate_fn)(const char*, double)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const std::string* flag, bool (*validate_fn)(const char*, const std::string&)); + +// Convenience macro for the registration of a flag validator +#define DEFINE_validator(name, validator) \ + static const bool name##_validator_registered = \ + GFLAGS_NAMESPACE::RegisterFlagValidator(&FLAGS_##name, validator) + + +// -------------------------------------------------------------------- +// These methods are the best way to get access to info about the +// list of commandline flags. Note that these routines are pretty slow. +// GetAllFlags: mostly-complete info about the list, sorted by file. +// ShowUsageWithFlags: pretty-prints the list to stdout (what --help does) +// ShowUsageWithFlagsRestrict: limit to filenames with restrict as a substr +// +// In addition to accessing flags, you can also access argv[0] (the program +// name) and argv (the entire commandline), which we sock away a copy of. +// These variables are static, so you should only set them once. +// +// No need to export this data only structure from DLL, avoiding VS warning 4251. +struct CommandLineFlagInfo { + std::string name; // the name of the flag + std::string type; // the type of the flag: int32, etc + std::string description; // the "help text" associated with the flag + std::string current_value; // the current value, as a string + std::string default_value; // the default value, as a string + std::string filename; // 'cleaned' version of filename holding the flag + bool has_validator_fn; // true if RegisterFlagValidator called on this flag + bool is_default; // true if the flag has the default value and + // has not been set explicitly from the cmdline + // or via SetCommandLineOption + const void* flag_ptr; // pointer to the flag's current value (i.e. FLAGS_foo) +}; + +// Using this inside of a validator is a recipe for a deadlock. +// TODO(user) Fix locking when validators are running, to make it safe to +// call validators during ParseAllFlags. +// Also make sure then to uncomment the corresponding unit test in +// gflags_unittest.sh +extern GFLAGS_DLL_DECL void GetAllFlags(std::vector* OUTPUT); +// These two are actually defined in gflags_reporting.cc. +extern GFLAGS_DLL_DECL void ShowUsageWithFlags(const char *argv0); // what --help does +extern GFLAGS_DLL_DECL void ShowUsageWithFlagsRestrict(const char *argv0, const char *restrict); + +// Create a descriptive string for a flag. +// Goes to some trouble to make pretty line breaks. +extern GFLAGS_DLL_DECL std::string DescribeOneFlag(const CommandLineFlagInfo& flag); + +// Thread-hostile; meant to be called before any threads are spawned. +extern GFLAGS_DLL_DECL void SetArgv(int argc, const char** argv); + +// The following functions are thread-safe as long as SetArgv() is +// only called before any threads start. +extern GFLAGS_DLL_DECL const std::vector& GetArgvs(); +extern GFLAGS_DLL_DECL const char* GetArgv(); // all of argv as a string +extern GFLAGS_DLL_DECL const char* GetArgv0(); // only argv0 +extern GFLAGS_DLL_DECL uint32 GetArgvSum(); // simple checksum of argv +extern GFLAGS_DLL_DECL const char* ProgramInvocationName(); // argv0, or "UNKNOWN" if not set +extern GFLAGS_DLL_DECL const char* ProgramInvocationShortName(); // basename(argv0) + +// ProgramUsage() is thread-safe as long as SetUsageMessage() is only +// called before any threads start. +extern GFLAGS_DLL_DECL const char* ProgramUsage(); // string set by SetUsageMessage() + +// VersionString() is thread-safe as long as SetVersionString() is only +// called before any threads start. +extern GFLAGS_DLL_DECL const char* VersionString(); // string set by SetVersionString() + + + +// -------------------------------------------------------------------- +// Normally you access commandline flags by just saying "if (FLAGS_foo)" +// or whatever, and set them by calling "FLAGS_foo = bar" (or, more +// commonly, via the DEFINE_foo macro). But if you need a bit more +// control, we have programmatic ways to get/set the flags as well. +// These programmatic ways to access flags are thread-safe, but direct +// access is only thread-compatible. + +// Return true iff the flagname was found. +// OUTPUT is set to the flag's value, or unchanged if we return false. +extern GFLAGS_DLL_DECL bool GetCommandLineOption(const char* name, std::string* OUTPUT); + +// Return true iff the flagname was found. OUTPUT is set to the flag's +// CommandLineFlagInfo or unchanged if we return false. +extern GFLAGS_DLL_DECL bool GetCommandLineFlagInfo(const char* name, CommandLineFlagInfo* OUTPUT); + +// Return the CommandLineFlagInfo of the flagname. exit() if name not found. +// Example usage, to check if a flag's value is currently the default value: +// if (GetCommandLineFlagInfoOrDie("foo").is_default) ... +extern GFLAGS_DLL_DECL CommandLineFlagInfo GetCommandLineFlagInfoOrDie(const char* name); + +enum GFLAGS_DLL_DECL FlagSettingMode { + // update the flag's value (can call this multiple times). + SET_FLAGS_VALUE, + // update the flag's value, but *only if* it has not yet been updated + // with SET_FLAGS_VALUE, SET_FLAG_IF_DEFAULT, or "FLAGS_xxx = nondef". + SET_FLAG_IF_DEFAULT, + // set the flag's default value to this. If the flag has not yet updated + // yet (via SET_FLAGS_VALUE, SET_FLAG_IF_DEFAULT, or "FLAGS_xxx = nondef") + // change the flag's current value to the new default value as well. + SET_FLAGS_DEFAULT +}; + +// Set a particular flag ("command line option"). Returns a string +// describing the new value that the option has been set to. The +// return value API is not well-specified, so basically just depend on +// it to be empty if the setting failed for some reason -- the name is +// not a valid flag name, or the value is not a valid value -- and +// non-empty else. + +// SetCommandLineOption uses set_mode == SET_FLAGS_VALUE (the common case) +extern GFLAGS_DLL_DECL std::string SetCommandLineOption (const char* name, const char* value); +extern GFLAGS_DLL_DECL std::string SetCommandLineOptionWithMode(const char* name, const char* value, FlagSettingMode set_mode); + + +// -------------------------------------------------------------------- +// Saves the states (value, default value, whether the user has set +// the flag, registered validators, etc) of all flags, and restores +// them when the FlagSaver is destroyed. This is very useful in +// tests, say, when you want to let your tests change the flags, but +// make sure that they get reverted to the original states when your +// test is complete. +// +// Example usage: +// void TestFoo() { +// FlagSaver s1; +// FLAG_foo = false; +// FLAG_bar = "some value"; +// +// // test happens here. You can return at any time +// // without worrying about restoring the FLAG values. +// } +// +// Note: This class is marked with GFLAGS_ATTRIBUTE_UNUSED because all +// the work is done in the constructor and destructor, so in the standard +// usage example above, the compiler would complain that it's an +// unused variable. +// +// This class is thread-safe. However, its destructor writes to +// exactly the set of flags that have changed value during its +// lifetime, so concurrent _direct_ access to those flags +// (i.e. FLAGS_foo instead of {Get,Set}CommandLineOption()) is unsafe. + +class GFLAGS_DLL_DECL FlagSaver { + public: + FlagSaver(); + ~FlagSaver(); + + private: + class FlagSaverImpl* impl_; // we use pimpl here to keep API steady + + FlagSaver(const FlagSaver&); // no copying! + void operator=(const FlagSaver&); +}__attribute((unused)); + +// -------------------------------------------------------------------- +// Some deprecated or hopefully-soon-to-be-deprecated functions. + +// This is often used for logging. TODO(csilvers): figure out a better way +extern GFLAGS_DLL_DECL std::string CommandlineFlagsIntoString(); +// Usually where this is used, a FlagSaver should be used instead. +extern GFLAGS_DLL_DECL +bool ReadFlagsFromString(const std::string& flagfilecontents, + const char* prog_name, + bool errors_are_fatal); // uses SET_FLAGS_VALUE + +// These let you manually implement --flagfile functionality. +// DEPRECATED. +extern GFLAGS_DLL_DECL bool AppendFlagsIntoFile(const std::string& filename, const char* prog_name); +extern GFLAGS_DLL_DECL bool ReadFromFlagsFile(const std::string& filename, const char* prog_name, bool errors_are_fatal); // uses SET_FLAGS_VALUE + + +// -------------------------------------------------------------------- +// Useful routines for initializing flags from the environment. +// In each case, if 'varname' does not exist in the environment +// return defval. If 'varname' does exist but is not valid +// (e.g., not a number for an int32 flag), abort with an error. +// Otherwise, return the value. NOTE: for booleans, for true use +// 't' or 'T' or 'true' or '1', for false 'f' or 'F' or 'false' or '0'. + +extern GFLAGS_DLL_DECL bool BoolFromEnv(const char *varname, bool defval); +extern GFLAGS_DLL_DECL int32 Int32FromEnv(const char *varname, int32 defval); +extern GFLAGS_DLL_DECL uint32 Uint32FromEnv(const char *varname, uint32 defval); +extern GFLAGS_DLL_DECL int64 Int64FromEnv(const char *varname, int64 defval); +extern GFLAGS_DLL_DECL uint64 Uint64FromEnv(const char *varname, uint64 defval); +extern GFLAGS_DLL_DECL double DoubleFromEnv(const char *varname, double defval); +extern GFLAGS_DLL_DECL const char *StringFromEnv(const char *varname, const char *defval); + + +// -------------------------------------------------------------------- +// The next two functions parse gflags from main(): + +// Set the "usage" message for this program. For example: +// string usage("This program does nothing. Sample usage:\n"); +// usage += argv[0] + " "; +// SetUsageMessage(usage); +// Do not include commandline flags in the usage: we do that for you! +// Thread-hostile; meant to be called before any threads are spawned. +extern GFLAGS_DLL_DECL void SetUsageMessage(const std::string& usage); + +// Sets the version string, which is emitted with --version. +// For instance: SetVersionString("1.3"); +// Thread-hostile; meant to be called before any threads are spawned. +extern GFLAGS_DLL_DECL void SetVersionString(const std::string& version); + + +// Looks for flags in argv and parses them. Rearranges argv to put +// flags first, or removes them entirely if remove_flags is true. +// If a flag is defined more than once in the command line or flag +// file, the last definition is used. Returns the index (into argv) +// of the first non-flag argument. +// See top-of-file for more details on this function. +#ifndef SWIG // In swig, use ParseCommandLineFlagsScript() instead. +extern GFLAGS_DLL_DECL uint32 ParseCommandLineFlags(int *argc, char*** argv, bool remove_flags); +#endif + + +// Calls to ParseCommandLineNonHelpFlags and then to +// HandleCommandLineHelpFlags can be used instead of a call to +// ParseCommandLineFlags during initialization, in order to allow for +// changing default values for some FLAGS (via +// e.g. SetCommandLineOptionWithMode calls) between the time of +// command line parsing and the time of dumping help information for +// the flags as a result of command line parsing. If a flag is +// defined more than once in the command line or flag file, the last +// definition is used. Returns the index (into argv) of the first +// non-flag argument. (If remove_flags is true, will always return 1.) +extern GFLAGS_DLL_DECL uint32 ParseCommandLineNonHelpFlags(int *argc, char*** argv, bool remove_flags); + +// This is actually defined in gflags_reporting.cc. +// This function is misnamed (it also handles --version, etc.), but +// it's too late to change that now. :-( +extern GFLAGS_DLL_DECL void HandleCommandLineHelpFlags(); // in gflags_reporting.cc + +// Allow command line reparsing. Disables the error normally +// generated when an unknown flag is found, since it may be found in a +// later parse. Thread-hostile; meant to be called before any threads +// are spawned. +extern GFLAGS_DLL_DECL void AllowCommandLineReparsing(); + +// Reparse the flags that have not yet been recognized. Only flags +// registered since the last parse will be recognized. Any flag value +// must be provided as part of the argument using "=", not as a +// separate command line argument that follows the flag argument. +// Intended for handling flags from dynamically loaded libraries, +// since their flags are not registered until they are loaded. +extern GFLAGS_DLL_DECL void ReparseCommandLineNonHelpFlags(); + +// Clean up memory allocated by flags. This is only needed to reduce +// the quantity of "potentially leaked" reports emitted by memory +// debugging tools such as valgrind. It is not required for normal +// operation, or for the google perftools heap-checker. It must only +// be called when the process is about to exit, and all threads that +// might access flags are quiescent. Referencing flags after this is +// called will have unexpected consequences. This is not safe to run +// when multiple threads might be running: the function is +// thread-hostile. +extern GFLAGS_DLL_DECL void ShutDownCommandLineFlags(); + + +// -------------------------------------------------------------------- +// Now come the command line flag declaration/definition macros that +// will actually be used. They're kind of hairy. A major reason +// for this is initialization: we want people to be able to access +// variables in global constructors and have that not crash, even if +// their global constructor runs before the global constructor here. +// (Obviously, we can't guarantee the flags will have the correct +// default value in that case, but at least accessing them is safe.) +// The only way to do that is have flags point to a static buffer. +// So we make one, using a union to ensure proper alignment, and +// then use placement-new to actually set up the flag with the +// correct default value. In the same vein, we have to worry about +// flag access in global destructors, so FlagRegisterer has to be +// careful never to destroy the flag-values it constructs. +// +// Note that when we define a flag variable FLAGS_, we also +// preemptively define a junk variable, FLAGS_no. This is to +// cause a link-time error if someone tries to define 2 flags with +// names like "logging" and "nologging". We do this because a bool +// flag FLAG can be set from the command line to true with a "-FLAG" +// argument, and to false with a "-noFLAG" argument, and so this can +// potentially avert confusion. +// +// We also put flags into their own namespace. It is purposefully +// named in an opaque way that people should have trouble typing +// directly. The idea is that DEFINE puts the flag in the weird +// namespace, and DECLARE imports the flag from there into the current +// namespace. The net result is to force people to use DECLARE to get +// access to a flag, rather than saying "extern GFLAGS_DLL_DECL bool FLAGS_whatever;" +// or some such instead. We want this so we can put extra +// functionality (like sanity-checking) in DECLARE if we want, and +// make sure it is picked up everywhere. +// +// We also put the type of the variable in the namespace, so that +// people can't DECLARE_int32 something that they DEFINE_bool'd +// elsewhere. + +class GFLAGS_DLL_DECL FlagRegisterer { + public: + // We instantiate this template ctor for all supported types, + // so it is possible to place implementation of the FlagRegisterer ctor in + // .cc file. + // Calling this constructor with unsupported type will produce linker error. + template + FlagRegisterer(const char* name, + const char* help, const char* filename, + FlagType* current_storage, FlagType* defvalue_storage); +}; + +// Force compiler to not generate code for the given template specialization. +#if defined(_MSC_VER) && _MSC_VER < 1800 // Visual Studio 2013 version 12.0 + #define GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(type) +#else + #define GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(type) \ + extern template GFLAGS_DLL_DECL FlagRegisterer::FlagRegisterer( \ + const char* name, const char* help, const char* filename, \ + type* current_storage, type* defvalue_storage) +#endif + +// Do this for all supported flag types. +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(bool); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(int32); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(uint32); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(int64); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(uint64); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(double); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(std::string); + +#undef GFLAGS_DECLARE_FLAG_REGISTERER_CTOR + +// If your application #defines STRIP_FLAG_HELP to a non-zero value +// before #including this file, we remove the help message from the +// binary file. This can reduce the size of the resulting binary +// somewhat, and may also be useful for security reasons. + +extern GFLAGS_DLL_DECL const char kStrippedFlagHelp[]; + + +} // namespace GFLAGS_NAMESPACE + + +#ifndef SWIG // In swig, ignore the main flag declarations + +#if defined(STRIP_FLAG_HELP) && STRIP_FLAG_HELP > 0 +// Need this construct to avoid the 'defined but not used' warning. +#define MAYBE_STRIPPED_HELP(txt) \ + (false ? (txt) : GFLAGS_NAMESPACE::kStrippedFlagHelp) +#else +#define MAYBE_STRIPPED_HELP(txt) txt +#endif + +// Each command-line flag has two variables associated with it: one +// with the current value, and one with the default value. However, +// we have a third variable, which is where value is assigned; it's a +// constant. This guarantees that FLAG_##value is initialized at +// static initialization time (e.g. before program-start) rather than +// than global construction time (which is after program-start but +// before main), at least when 'value' is a compile-time constant. We +// use a small trick for the "default value" variable, and call it +// FLAGS_no. This serves the second purpose of assuring a +// compile error if someone tries to define a flag named no +// which is illegal (--foo and --nofoo both affect the "foo" flag). +#define DEFINE_VARIABLE(type, shorttype, name, value, help) \ + namespace fL##shorttype { \ + static const type FLAGS_nono##name = value; \ + /* We always want to export defined variables, dll or no */ \ + GFLAGS_DLL_DEFINE_FLAG type FLAGS_##name = FLAGS_nono##name; \ + static type FLAGS_no##name = FLAGS_nono##name; \ + static GFLAGS_NAMESPACE::FlagRegisterer o_##name( \ + #name, MAYBE_STRIPPED_HELP(help), __FILE__, \ + &FLAGS_##name, &FLAGS_no##name); \ + } \ + using fL##shorttype::FLAGS_##name + +// For DEFINE_bool, we want to do the extra check that the passed-in +// value is actually a bool, and not a string or something that can be +// coerced to a bool. These declarations (no definition needed!) will +// help us do that, and never evaluate From, which is important. +// We'll use 'sizeof(IsBool(val))' to distinguish. This code requires +// that the compiler have different sizes for bool & double. Since +// this is not guaranteed by the standard, we check it with a +// COMPILE_ASSERT. +namespace fLB { +struct CompileAssert {}; +typedef CompileAssert expected_sizeof_double_neq_sizeof_bool[ + (sizeof(double) != sizeof(bool)) ? 1 : -1]; +template double GFLAGS_DLL_DECL IsBoolFlag(const From& from); +GFLAGS_DLL_DECL bool IsBoolFlag(bool from); +} // namespace fLB + +// Here are the actual DEFINE_*-macros. The respective DECLARE_*-macros +// are in a separate include, gflags_declare.h, for reducing +// the physical transitive size for DECLARE use. +#define DEFINE_bool(name, val, txt) \ + namespace fLB { \ + typedef ::fLB::CompileAssert FLAG_##name##_value_is_not_a_bool[ \ + (sizeof(::fLB::IsBoolFlag(val)) != sizeof(double))? 1: -1]; \ + } \ + DEFINE_VARIABLE(bool, B, name, val, txt) + +#define DEFINE_int32(name, val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::int32, I, \ + name, val, txt) + +#define DEFINE_uint32(name,val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::uint32, U, \ + name, val, txt) + +#define DEFINE_int64(name, val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::int64, I64, \ + name, val, txt) + +#define DEFINE_uint64(name,val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::uint64, U64, \ + name, val, txt) + +#define DEFINE_double(name, val, txt) \ + DEFINE_VARIABLE(double, D, name, val, txt) + +// Strings are trickier, because they're not a POD, so we can't +// construct them at static-initialization time (instead they get +// constructed at global-constructor time, which is much later). To +// try to avoid crashes in that case, we use a char buffer to store +// the string, which we can static-initialize, and then placement-new +// into it later. It's not perfect, but the best we can do. + +namespace fLS { + +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + const char *value) { + return new(stringspot) clstring(value); +} +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + const clstring &value) { + return new(stringspot) clstring(value); +} +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + int value); + +// Auxiliary class used to explicitly call destructor of string objects +// allocated using placement new during static program deinitialization. +// The destructor MUST be an inline function such that the explicit +// destruction occurs in the same compilation unit as the placement new. +class StringFlagDestructor { + void *current_storage_; + void *defvalue_storage_; + +public: + + StringFlagDestructor(void *current, void *defvalue) + : current_storage_(current), defvalue_storage_(defvalue) {} + + ~StringFlagDestructor() { + reinterpret_cast(current_storage_ )->~clstring(); + reinterpret_cast(defvalue_storage_)->~clstring(); + } +}; + +} // namespace fLS + +// We need to define a var named FLAGS_no##name so people don't define +// --string and --nostring. And we need a temporary place to put val +// so we don't have to evaluate it twice. Two great needs that go +// great together! +// The weird 'using' + 'extern' inside the fLS namespace is to work around +// an unknown compiler bug/issue with the gcc 4.2.1 on SUSE 10. See +// http://code.google.com/p/google-gflags/issues/detail?id=20 +#define DEFINE_string(name, val, txt) \ + namespace fLS { \ + using ::fLS::clstring; \ + using ::fLS::StringFlagDestructor; \ + static union { void* align; char s[sizeof(clstring)]; } s_##name[2]; \ + clstring* const FLAGS_no##name = ::fLS:: \ + dont_pass0toDEFINE_string(s_##name[0].s, \ + val); \ + static GFLAGS_NAMESPACE::FlagRegisterer o_##name( \ + #name, MAYBE_STRIPPED_HELP(txt), __FILE__, \ + FLAGS_no##name, new (s_##name[1].s) clstring(*FLAGS_no##name)); \ + static StringFlagDestructor d_##name(s_##name[0].s, s_##name[1].s); \ + extern GFLAGS_DLL_DEFINE_FLAG clstring& FLAGS_##name; \ + using fLS::FLAGS_##name; \ + clstring& FLAGS_##name = *FLAGS_no##name; \ + } \ + using fLS::FLAGS_##name + +#endif // SWIG + + + + + +#endif // GFLAGS_GFLAGS_H_ diff --git a/projects/llm_framework/include/gflags/gflags_completions.h b/projects/llm_framework/include/gflags/gflags_completions.h new file mode 100644 index 00000000..2fa0db6d --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags_completions.h @@ -0,0 +1,121 @@ +// Copyright (c) 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// --- + +// +// Implement helpful bash-style command line flag completions +// +// ** Functional API: +// HandleCommandLineCompletions() should be called early during +// program startup, but after command line flag code has been +// initialized, such as the beginning of HandleCommandLineHelpFlags(). +// It checks the value of the flag --tab_completion_word. If this +// flag is empty, nothing happens here. If it contains a string, +// however, then HandleCommandLineCompletions() will hijack the +// process, attempting to identify the intention behind this +// completion. Regardless of the outcome of this deduction, the +// process will be terminated, similar to --helpshort flag +// handling. +// +// ** Overview of Bash completions: +// Bash can be told to programatically determine completions for the +// current 'cursor word'. It does this by (in this case) invoking a +// command with some additional arguments identifying the command +// being executed, the word being completed, and the previous word +// (if any). Bash then expects a sequence of output lines to be +// printed to stdout. If these lines all contain a common prefix +// longer than the cursor word, bash will replace the cursor word +// with that common prefix, and display nothing. If there isn't such +// a common prefix, bash will display the lines in pages using 'more'. +// +// ** Strategy taken for command line completions: +// If we can deduce either the exact flag intended, or a common flag +// prefix, we'll output exactly that. Otherwise, if information +// must be displayed to the user, we'll take the opportunity to add +// some helpful information beyond just the flag name (specifically, +// we'll include the default flag value and as much of the flag's +// description as can fit on a single terminal line width, as specified +// by the flag --tab_completion_columns). Furthermore, we'll try to +// make bash order the output such that the most useful or relevent +// flags are the most likely to be shown at the top. +// +// ** Additional features: +// To assist in finding that one really useful flag, substring matching +// was implemented. Before pressing a to get completion for the +// current word, you can append one or more '?' to the flag to do +// substring matching. Here's the semantics: +// --foo Show me all flags with names prefixed by 'foo' +// --foo? Show me all flags with 'foo' somewhere in the name +// --foo?? Same as prior case, but also search in module +// definition path for 'foo' +// --foo??? Same as prior case, but also search in flag +// descriptions for 'foo' +// Finally, we'll trim the output to a relatively small number of +// flags to keep bash quiet about the verbosity of output. If one +// really wanted to see all possible matches, appending a '+' to the +// search word will force the exhaustive list of matches to be printed. +// +// ** How to have bash accept completions from a binary: +// Bash requires that it be informed about each command that programmatic +// completion should be enabled for. Example addition to a .bashrc +// file would be (your path to gflags_completions.sh file may differ): + +/* +$ complete -o bashdefault -o default -o nospace -C \ + '/home/build/eng/bash/bash_completions.sh --tab_completion_columns $COLUMNS' \ + time env binary_name another_binary [...] +*/ + +// This would allow the following to work: +// $ /path/to/binary_name --vmodule +// Or: +// $ ./bin/path/another_binary --gfs_u +// (etc) +// +// Sadly, it appears that bash gives no easy way to force this behavior for +// all commands. That's where the "time" in the above example comes in. +// If you haven't specifically added a command to the list of completion +// supported commands, you can still get completions by prefixing the +// entire command with "env". +// $ env /some/brand/new/binary --vmod +// Assuming that "binary" is a newly compiled binary, this should still +// produce the expected completion output. + + +#ifndef GFLAGS_COMPLETIONS_H_ +#define GFLAGS_COMPLETIONS_H_ + +namespace gflags { + +extern void HandleCommandLineCompletions(void); + +} + +#endif // GFLAGS_COMPLETIONS_H_ diff --git a/projects/llm_framework/include/gflags/gflags_declare.h b/projects/llm_framework/include/gflags/gflags_declare.h new file mode 100644 index 00000000..69cf1129 --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags_declare.h @@ -0,0 +1,156 @@ +// Copyright (c) 1999, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// Revamped and reorganized by Craig Silverstein +// +// This is the file that should be included by any file which declares +// command line flag. + +#ifndef GFLAGS_DECLARE_H_ +#define GFLAGS_DECLARE_H_ + + +// --------------------------------------------------------------------------- +// Namespace of gflags library symbols. +#define GFLAGS_NAMESPACE gflags + +// --------------------------------------------------------------------------- +// Windows DLL import/export. + +// Whether gflags library is a DLL. +// +// Set to 1 by default when the shared gflags library was built on Windows. +// Must be overwritten when this header file is used with the optionally also +// built static library instead; set by CMake's INTERFACE_COMPILE_DEFINITIONS. +#ifndef GFLAGS_IS_A_DLL +# define GFLAGS_IS_A_DLL 1 +#endif + +// We always want to import the symbols of the gflags library. +#ifndef GFLAGS_DLL_DECL +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DECL __declspec(dllimport) +# elif defined(__GNUC__) && __GNUC__ >= 4 +# define GFLAGS_DLL_DECL __attribute__((visibility("default"))) +# else +# define GFLAGS_DLL_DECL +# endif +#endif + +// We always want to import variables declared in user code. +#ifndef GFLAGS_DLL_DECLARE_FLAG +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DECLARE_FLAG __declspec(dllimport) +# elif defined(__GNUC__) && __GNUC__ >= 4 +# define GFLAGS_DLL_DECLARE_FLAG __attribute__((visibility("default"))) +# else +# define GFLAGS_DLL_DECLARE_FLAG +# endif +#endif + +// --------------------------------------------------------------------------- +// Flag types +#include +#if 1 +# include // the normal place uint32_t is defined +#elif 1 +# include // the normal place u_int32_t is defined +#elif 1 +# include // a third place for uint32_t or u_int32_t +#endif + +namespace GFLAGS_NAMESPACE { + +#if 1 // C99 +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#elif 0 // BSD +typedef int32_t int32; +typedef u_int32_t uint32; +typedef int64_t int64; +typedef u_int64_t uint64; +#elif 0 // Windows +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#else +# error Do not know how to define a 32-bit integer quantity on your system +#endif + +} // namespace GFLAGS_NAMESPACE + + +namespace fLS { + +// The meaning of "string" might be different between now and when the +// macros below get invoked (e.g., if someone is experimenting with +// other string implementations that get defined after this file is +// included). Save the current meaning now and use it in the macros. +typedef std::string clstring; + +} // namespace fLS + + +#define DECLARE_VARIABLE(type, shorttype, name) \ + /* We always want to import declared variables, dll or no */ \ + namespace fL##shorttype { extern GFLAGS_DLL_DECLARE_FLAG type FLAGS_##name; } \ + using fL##shorttype::FLAGS_##name + +#define DECLARE_bool(name) \ + DECLARE_VARIABLE(bool, B, name) + +#define DECLARE_int32(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::int32, I, name) + +#define DECLARE_uint32(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::uint32, U, name) + +#define DECLARE_int64(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::int64, I64, name) + +#define DECLARE_uint64(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::uint64, U64, name) + +#define DECLARE_double(name) \ + DECLARE_VARIABLE(double, D, name) + +#define DECLARE_string(name) \ + /* We always want to import declared variables, dll or no */ \ + namespace fLS { \ + extern GFLAGS_DLL_DECLARE_FLAG ::fLS::clstring& FLAGS_##name; \ + } \ + using fLS::FLAGS_##name + + +#endif // GFLAGS_DECLARE_H_ diff --git a/projects/llm_framework/include/glog/log_severity.h b/projects/llm_framework/include/glog/log_severity.h new file mode 100644 index 00000000..99945a42 --- /dev/null +++ b/projects/llm_framework/include/glog/log_severity.h @@ -0,0 +1,92 @@ +// Copyright (c) 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef BASE_LOG_SEVERITY_H__ +#define BASE_LOG_SEVERITY_H__ + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// Variables of type LogSeverity are widely taken to lie in the range +// [0, NUM_SEVERITIES-1]. Be careful to preserve this assumption if +// you ever need to change their values or add a new severity. +typedef int LogSeverity; + +const int GLOG_INFO = 0, GLOG_WARNING = 1, GLOG_ERROR = 2, GLOG_FATAL = 3, + NUM_SEVERITIES = 4; +#ifndef GLOG_NO_ABBREVIATED_SEVERITIES +# ifdef ERROR +# error ERROR macro is defined. Define GLOG_NO_ABBREVIATED_SEVERITIES before including logging.h. See the document for detail. +# endif +const int INFO = GLOG_INFO, WARNING = GLOG_WARNING, + ERROR = GLOG_ERROR, FATAL = GLOG_FATAL; +#endif + +// DFATAL is FATAL in debug mode, ERROR in normal mode +#ifdef NDEBUG +#define DFATAL_LEVEL ERROR +#else +#define DFATAL_LEVEL FATAL +#endif + +extern GOOGLE_GLOG_DLL_DECL const char* const LogSeverityNames[NUM_SEVERITIES]; + +// NDEBUG usage helpers related to (RAW_)DCHECK: +// +// DEBUG_MODE is for small !NDEBUG uses like +// if (DEBUG_MODE) foo.CheckThatFoo(); +// instead of substantially more verbose +// #ifndef NDEBUG +// foo.CheckThatFoo(); +// #endif +// +// IF_DEBUG_MODE is for small !NDEBUG uses like +// IF_DEBUG_MODE( string error; ) +// DCHECK(Foo(&error)) << error; +// instead of substantially more verbose +// #ifndef NDEBUG +// string error; +// DCHECK(Foo(&error)) << error; +// #endif +// +#ifdef NDEBUG +enum { DEBUG_MODE = 0 }; +#define IF_DEBUG_MODE(x) +#else +enum { DEBUG_MODE = 1 }; +#define IF_DEBUG_MODE(x) x +#endif + +#endif // BASE_LOG_SEVERITY_H__ diff --git a/projects/llm_framework/include/glog/logging.h b/projects/llm_framework/include/glog/logging.h new file mode 100644 index 00000000..4cd247e1 --- /dev/null +++ b/projects/llm_framework/include/glog/logging.h @@ -0,0 +1,1662 @@ +// Copyright (c) 1999, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: Ray Sidney +// +// This file contains #include information about logging-related stuff. +// Pretty much everybody needs to #include this file so that they can +// log various happenings. +// +#ifndef _LOGGING_H_ +#define _LOGGING_H_ + +#include +#include +#include +#include +#include +#include +#include +#if 1 +# include +#endif +#include + +#if defined(_MSC_VER) +#define GLOG_MSVC_PUSH_DISABLE_WARNING(n) __pragma(warning(push)) \ + __pragma(warning(disable:n)) +#define GLOG_MSVC_POP_WARNING() __pragma(warning(pop)) +#else +#define GLOG_MSVC_PUSH_DISABLE_WARNING(n) +#define GLOG_MSVC_POP_WARNING() +#endif + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// We care a lot about number of bits things take up. Unfortunately, +// systems define their bit-specific ints in a lot of different ways. +// We use our own way, and have a typedef to get there. +// Note: these commands below may look like "#if 1" or "#if 0", but +// that's because they were constructed that way at ./configure time. +// Look at logging.h.in to see how they're calculated (based on your config). +#if 1 +#include // the normal place uint16_t is defined +#endif +#if 1 +#include // the normal place u_int16_t is defined +#endif +#if 1 +#include // a third place for uint16_t or u_int16_t +#endif + +#if 0 +#include +#endif + +namespace google { + +#if 1 // the C99 format +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#elif 1 // the BSD format +typedef int32_t int32; +typedef u_int32_t uint32; +typedef int64_t int64; +typedef u_int64_t uint64; +#elif 0 // the windows (vc7) format +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#else +#error Do not know how to define a 32-bit integer quantity on your system +#endif + +} + +// The global value of GOOGLE_STRIP_LOG. All the messages logged to +// LOG(XXX) with severity less than GOOGLE_STRIP_LOG will not be displayed. +// If it can be determined at compile time that the message will not be +// printed, the statement will be compiled out. +// +// Example: to strip out all INFO and WARNING messages, use the value +// of 2 below. To make an exception for WARNING messages from a single +// file, add "#define GOOGLE_STRIP_LOG 1" to that file _before_ including +// base/logging.h +#ifndef GOOGLE_STRIP_LOG +#define GOOGLE_STRIP_LOG 0 +#endif + +// GCC can be told that a certain branch is not likely to be taken (for +// instance, a CHECK failure), and use that information in static analysis. +// Giving it this information can help it optimize for the common case in +// the absence of better information (ie. -fprofile-arcs). +// +#ifndef GOOGLE_PREDICT_BRANCH_NOT_TAKEN +#if 1 +#define GOOGLE_PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0)) +#else +#define GOOGLE_PREDICT_BRANCH_NOT_TAKEN(x) x +#endif +#endif + +#ifndef GOOGLE_PREDICT_FALSE +#if 1 +#define GOOGLE_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#else +#define GOOGLE_PREDICT_FALSE(x) x +#endif +#endif + +#ifndef GOOGLE_PREDICT_TRUE +#if 1 +#define GOOGLE_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define GOOGLE_PREDICT_TRUE(x) x +#endif +#endif + + +// Make a bunch of macros for logging. The way to log things is to stream +// things to LOG(). E.g., +// +// LOG(INFO) << "Found " << num_cookies << " cookies"; +// +// You can capture log messages in a string, rather than reporting them +// immediately: +// +// vector errors; +// LOG_STRING(ERROR, &errors) << "Couldn't parse cookie #" << cookie_num; +// +// This pushes back the new error onto 'errors'; if given a NULL pointer, +// it reports the error via LOG(ERROR). +// +// You can also do conditional logging: +// +// LOG_IF(INFO, num_cookies > 10) << "Got lots of cookies"; +// +// You can also do occasional logging (log every n'th occurrence of an +// event): +// +// LOG_EVERY_N(INFO, 10) << "Got the " << google::COUNTER << "th cookie"; +// +// The above will cause log messages to be output on the 1st, 11th, 21st, ... +// times it is executed. Note that the special google::COUNTER value is used +// to identify which repetition is happening. +// +// You can also do occasional conditional logging (log every n'th +// occurrence of an event, when condition is satisfied): +// +// LOG_IF_EVERY_N(INFO, (size > 1024), 10) << "Got the " << google::COUNTER +// << "th big cookie"; +// +// You can log messages the first N times your code executes a line. E.g. +// +// LOG_FIRST_N(INFO, 20) << "Got the " << google::COUNTER << "th cookie"; +// +// Outputs log messages for the first 20 times it is executed. +// +// Analogous SYSLOG, SYSLOG_IF, and SYSLOG_EVERY_N macros are available. +// These log to syslog as well as to the normal logs. If you use these at +// all, you need to be aware that syslog can drastically reduce performance, +// especially if it is configured for remote logging! Don't use these +// unless you fully understand this and have a concrete need to use them. +// Even then, try to minimize your use of them. +// +// There are also "debug mode" logging macros like the ones above: +// +// DLOG(INFO) << "Found cookies"; +// +// DLOG_IF(INFO, num_cookies > 10) << "Got lots of cookies"; +// +// DLOG_EVERY_N(INFO, 10) << "Got the " << google::COUNTER << "th cookie"; +// +// All "debug mode" logging is compiled away to nothing for non-debug mode +// compiles. +// +// We also have +// +// LOG_ASSERT(assertion); +// DLOG_ASSERT(assertion); +// +// which is syntactic sugar for {,D}LOG_IF(FATAL, assert fails) << assertion; +// +// There are "verbose level" logging macros. They look like +// +// VLOG(1) << "I'm printed when you run the program with --v=1 or more"; +// VLOG(2) << "I'm printed when you run the program with --v=2 or more"; +// +// These always log at the INFO log level (when they log at all). +// The verbose logging can also be turned on module-by-module. For instance, +// --vmodule=mapreduce=2,file=1,gfs*=3 --v=0 +// will cause: +// a. VLOG(2) and lower messages to be printed from mapreduce.{h,cc} +// b. VLOG(1) and lower messages to be printed from file.{h,cc} +// c. VLOG(3) and lower messages to be printed from files prefixed with "gfs" +// d. VLOG(0) and lower messages to be printed from elsewhere +// +// The wildcarding functionality shown by (c) supports both '*' (match +// 0 or more characters) and '?' (match any single character) wildcards. +// +// There's also VLOG_IS_ON(n) "verbose level" condition macro. To be used as +// +// if (VLOG_IS_ON(2)) { +// // do some logging preparation and logging +// // that can't be accomplished with just VLOG(2) << ...; +// } +// +// There are also VLOG_IF, VLOG_EVERY_N and VLOG_IF_EVERY_N "verbose level" +// condition macros for sample cases, when some extra computation and +// preparation for logs is not needed. +// VLOG_IF(1, (size > 1024)) +// << "I'm printed when size is more than 1024 and when you run the " +// "program with --v=1 or more"; +// VLOG_EVERY_N(1, 10) +// << "I'm printed every 10th occurrence, and when you run the program " +// "with --v=1 or more. Present occurence is " << google::COUNTER; +// VLOG_IF_EVERY_N(1, (size > 1024), 10) +// << "I'm printed on every 10th occurence of case when size is more " +// " than 1024, when you run the program with --v=1 or more. "; +// "Present occurence is " << google::COUNTER; +// +// The supported severity levels for macros that allow you to specify one +// are (in increasing order of severity) INFO, WARNING, ERROR, and FATAL. +// Note that messages of a given severity are logged not only in the +// logfile for that severity, but also in all logfiles of lower severity. +// E.g., a message of severity FATAL will be logged to the logfiles of +// severity FATAL, ERROR, WARNING, and INFO. +// +// There is also the special severity of DFATAL, which logs FATAL in +// debug mode, ERROR in normal mode. +// +// Very important: logging a message at the FATAL severity level causes +// the program to terminate (after the message is logged). +// +// Unless otherwise specified, logs will be written to the filename +// "...log..", followed +// by the date, time, and pid (you can't prevent the date, time, and pid +// from being in the filename). +// +// The logging code takes two flags: +// --v=# set the verbose level +// --logtostderr log all the messages to stderr instead of to logfiles + +// LOG LINE PREFIX FORMAT +// +// Log lines have this form: +// +// Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg... +// +// where the fields are defined as follows: +// +// L A single character, representing the log level +// (eg 'I' for INFO) +// mm The month (zero padded; ie May is '05') +// dd The day (zero padded) +// hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds +// threadid The space-padded thread ID as returned by GetTID() +// (this matches the PID on Linux) +// file The file name +// line The line number +// msg The user-supplied message +// +// Example: +// +// I1103 11:57:31.739339 24395 google.cc:2341] Command line: ./some_prog +// I1103 11:57:31.739403 24395 google.cc:2342] Process id 24395 +// +// NOTE: although the microseconds are useful for comparing events on +// a single machine, clocks on different machines may not be well +// synchronized. Hence, use caution when comparing the low bits of +// timestamps from different machines. + +#ifndef DECLARE_VARIABLE +#define MUST_UNDEF_GFLAGS_DECLARE_MACROS +#define DECLARE_VARIABLE(type, shorttype, name, tn) \ + namespace fL##shorttype { \ + extern GOOGLE_GLOG_DLL_DECL type FLAGS_##name; \ + } \ + using fL##shorttype::FLAGS_##name + +// bool specialization +#define DECLARE_bool(name) \ + DECLARE_VARIABLE(bool, B, name, bool) + +// int32 specialization +#define DECLARE_int32(name) \ + DECLARE_VARIABLE(google::int32, I, name, int32) + +// Special case for string, because we have to specify the namespace +// std::string, which doesn't play nicely with our FLAG__namespace hackery. +#define DECLARE_string(name) \ + namespace fLS { \ + extern GOOGLE_GLOG_DLL_DECL std::string& FLAGS_##name; \ + } \ + using fLS::FLAGS_##name +#endif + +// Set whether log messages go to stderr instead of logfiles +DECLARE_bool(logtostderr); + +// Set whether log messages go to stderr in addition to logfiles. +DECLARE_bool(alsologtostderr); + +// Set color messages logged to stderr (if supported by terminal). +DECLARE_bool(colorlogtostderr); + +// Log messages at a level >= this flag are automatically sent to +// stderr in addition to log files. +DECLARE_int32(stderrthreshold); + +// Set whether the log prefix should be prepended to each line of output. +DECLARE_bool(log_prefix); + +// Log messages at a level <= this flag are buffered. +// Log messages at a higher level are flushed immediately. +DECLARE_int32(logbuflevel); + +// Sets the maximum number of seconds which logs may be buffered for. +DECLARE_int32(logbufsecs); + +// Log suppression level: messages logged at a lower level than this +// are suppressed. +DECLARE_int32(minloglevel); + +// If specified, logfiles are written into this directory instead of the +// default logging directory. +DECLARE_string(log_dir); + +// Set the log file mode. +DECLARE_int32(logfile_mode); + +// Sets the path of the directory into which to put additional links +// to the log files. +DECLARE_string(log_link); + +DECLARE_int32(v); // in vlog_is_on.cc + +// Sets the maximum log file size (in MB). +DECLARE_int32(max_log_size); + +// Sets whether to avoid logging to the disk if the disk is full. +DECLARE_bool(stop_logging_if_full_disk); + +#ifdef MUST_UNDEF_GFLAGS_DECLARE_MACROS +#undef MUST_UNDEF_GFLAGS_DECLARE_MACROS +#undef DECLARE_VARIABLE +#undef DECLARE_bool +#undef DECLARE_int32 +#undef DECLARE_string +#endif + +// Log messages below the GOOGLE_STRIP_LOG level will be compiled away for +// security reasons. See LOG(severtiy) below. + +// A few definitions of macros that don't generate much code. Since +// LOG(INFO) and its ilk are used all over our code, it's +// better to have compact code for these operations. + +#if GOOGLE_STRIP_LOG == 0 +#define COMPACT_GOOGLE_LOG_INFO google::LogMessage( \ + __FILE__, __LINE__) +#define LOG_TO_STRING_INFO(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_INFO, message) +#else +#define COMPACT_GOOGLE_LOG_INFO google::NullStream() +#define LOG_TO_STRING_INFO(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 1 +#define COMPACT_GOOGLE_LOG_WARNING google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_WARNING) +#define LOG_TO_STRING_WARNING(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_WARNING, message) +#else +#define COMPACT_GOOGLE_LOG_WARNING google::NullStream() +#define LOG_TO_STRING_WARNING(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 2 +#define COMPACT_GOOGLE_LOG_ERROR google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ERROR) +#define LOG_TO_STRING_ERROR(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ERROR, message) +#else +#define COMPACT_GOOGLE_LOG_ERROR google::NullStream() +#define LOG_TO_STRING_ERROR(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 3 +#define COMPACT_GOOGLE_LOG_FATAL google::LogMessageFatal( \ + __FILE__, __LINE__) +#define LOG_TO_STRING_FATAL(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_FATAL, message) +#else +#define COMPACT_GOOGLE_LOG_FATAL google::NullStreamFatal() +#define LOG_TO_STRING_FATAL(message) google::NullStreamFatal() +#endif + +#if defined(NDEBUG) && !defined(DCHECK_ALWAYS_ON) +#define DCHECK_IS_ON() 0 +#else +#define DCHECK_IS_ON() 1 +#endif + +// For DFATAL, we want to use LogMessage (as opposed to +// LogMessageFatal), to be consistent with the original behavior. +#if !DCHECK_IS_ON() +#define COMPACT_GOOGLE_LOG_DFATAL COMPACT_GOOGLE_LOG_ERROR +#elif GOOGLE_STRIP_LOG <= 3 +#define COMPACT_GOOGLE_LOG_DFATAL google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_FATAL) +#else +#define COMPACT_GOOGLE_LOG_DFATAL google::NullStreamFatal() +#endif + +#define GOOGLE_LOG_INFO(counter) google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO, counter, &google::LogMessage::SendToLog) +#define SYSLOG_INFO(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_WARNING(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_WARNING(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_ERROR(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_ERROR(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_FATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_FATAL, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_FATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_FATAL, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_DFATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::DFATAL_LEVEL, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_DFATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::DFATAL_LEVEL, counter, \ + &google::LogMessage::SendToSyslogAndLog) + +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__) || defined(__CYGWIN32__) +// A very useful logging macro to log windows errors: +#define LOG_SYSRESULT(result) \ + if (FAILED(HRESULT_FROM_WIN32(result))) { \ + LPSTR message = NULL; \ + LPSTR msg = reinterpret_cast(&message); \ + DWORD message_length = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | \ + FORMAT_MESSAGE_FROM_SYSTEM, \ + 0, result, 0, msg, 100, NULL); \ + if (message_length > 0) { \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, 0, \ + &google::LogMessage::SendToLog).stream() \ + << reinterpret_cast(message); \ + LocalFree(message); \ + } \ + } +#endif + +// We use the preprocessor's merging operator, "##", so that, e.g., +// LOG(INFO) becomes the token GOOGLE_LOG_INFO. There's some funny +// subtle difference between ostream member streaming functions (e.g., +// ostream::operator<<(int) and ostream non-member streaming functions +// (e.g., ::operator<<(ostream&, string&): it turns out that it's +// impossible to stream something like a string directly to an unnamed +// ostream. We employ a neat hack by calling the stream() member +// function of LogMessage which seems to avoid the problem. +#define LOG(severity) COMPACT_GOOGLE_LOG_ ## severity.stream() +#define SYSLOG(severity) SYSLOG_ ## severity(0).stream() + +namespace google { + +// They need the definitions of integer types. +#include "glog/log_severity.h" +#include "glog/vlog_is_on.h" + +// Initialize google's logging library. You will see the program name +// specified by argv0 in log outputs. +GOOGLE_GLOG_DLL_DECL void InitGoogleLogging(const char* argv0); + +// Shutdown google's logging library. +GOOGLE_GLOG_DLL_DECL void ShutdownGoogleLogging(); + +// Install a function which will be called after LOG(FATAL). +GOOGLE_GLOG_DLL_DECL void InstallFailureFunction(void (*fail_func)()); + +class LogSink; // defined below + +// If a non-NULL sink pointer is given, we push this message to that sink. +// For LOG_TO_SINK we then do normal LOG(severity) logging as well. +// This is useful for capturing messages and passing/storing them +// somewhere more specific than the global log of the process. +// Argument types: +// LogSink* sink; +// LogSeverity severity; +// The cast is to disambiguate NULL arguments. +#define LOG_TO_SINK(sink, severity) \ + google::LogMessage( \ + __FILE__, __LINE__, \ + google::GLOG_ ## severity, \ + static_cast(sink), true).stream() +#define LOG_TO_SINK_BUT_NOT_TO_LOGFILE(sink, severity) \ + google::LogMessage( \ + __FILE__, __LINE__, \ + google::GLOG_ ## severity, \ + static_cast(sink), false).stream() + +// If a non-NULL string pointer is given, we write this message to that string. +// We then do normal LOG(severity) logging as well. +// This is useful for capturing messages and storing them somewhere more +// specific than the global log of the process. +// Argument types: +// string* message; +// LogSeverity severity; +// The cast is to disambiguate NULL arguments. +// NOTE: LOG(severity) expands to LogMessage().stream() for the specified +// severity. +#define LOG_TO_STRING(severity, message) \ + LOG_TO_STRING_##severity(static_cast(message)).stream() + +// If a non-NULL pointer is given, we push the message onto the end +// of a vector of strings; otherwise, we report it with LOG(severity). +// This is handy for capturing messages and perhaps passing them back +// to the caller, rather than reporting them immediately. +// Argument types: +// LogSeverity severity; +// vector *outvec; +// The cast is to disambiguate NULL arguments. +#define LOG_STRING(severity, outvec) \ + LOG_TO_STRING_##severity(static_cast*>(outvec)).stream() + +#define LOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? (void) 0 : google::LogMessageVoidify() & LOG(severity) +#define SYSLOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? (void) 0 : google::LogMessageVoidify() & SYSLOG(severity) + +#define LOG_ASSERT(condition) \ + LOG_IF(FATAL, !(condition)) << "Assert failed: " #condition +#define SYSLOG_ASSERT(condition) \ + SYSLOG_IF(FATAL, !(condition)) << "Assert failed: " #condition + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by DCHECK_IS_ON(), so the check will be executed regardless of +// compilation mode. Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + LOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!(condition))) \ + << "Check failed: " #condition " " + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is NULL. +struct CheckOpString { + CheckOpString(std::string* str) : str_(str) { } + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + operator bool() const { + return GOOGLE_PREDICT_BRANCH_NOT_TAKEN(str_ != NULL); + } + std::string* str_; +}; + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. +template +inline const T& GetReferenceableValue(const T& t) { return t; } +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline short GetReferenceableValue(short t) { return t; } +inline unsigned short GetReferenceableValue(unsigned short t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline long GetReferenceableValue(long t) { return t; } +inline unsigned long GetReferenceableValue(unsigned long t) { return t; } +inline long long GetReferenceableValue(long long t) { return t; } +inline unsigned long long GetReferenceableValue(unsigned long long t) { + return t; +} + +// This is a dummy class to define the following operator. +struct DummyClassToDefineOperator {}; + +} + +// Define global operator<< to declare using ::operator<<. +// This declaration will allow use to use CHECK macros for user +// defined classes which have operator<< (e.g., stl_logging.h). +inline std::ostream& operator<<( + std::ostream& out, const google::DummyClassToDefineOperator&) { + return out; +} + +namespace google { + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream* os, const T& v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const char& v); +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const signed char& v); +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); + +// Build the error message string. Specify no inlining for code size. +template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) + __attribute__((noinline)); + +namespace base { +namespace internal { + +// If "s" is less than base_logging::INFO, returns base_logging::INFO. +// If "s" is greater than base_logging::FATAL, returns +// base_logging::ERROR. Otherwise, returns "s". +LogSeverity NormalizeSeverity(LogSeverity s); + +} // namespace internal + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class GOOGLE_GLOG_DLL_DECL CheckOpMessageBuilder { + public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char *exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream* ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream* ForVar2(); + // Get the result (inserts the closing ")"). + std::string* NewString(); + + private: + std::ostringstream *stream_; +}; + +} // namespace base + +template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { + base::CheckOpMessageBuilder comb(exprtext); + MakeCheckOpValueString(comb.ForVar1(), v1); + MakeCheckOpValueString(comb.ForVar2(), v2); + return comb.NewString(); +} + +// Helper functions for CHECK_OP macro. +// The (int, int) specialization works around the issue that the compiler +// will not instantiate the template version of the function on values of +// unnamed enum type - see comment below. +#define DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline std::string* name##Impl(const T1& v1, const T2& v2, \ + const char* exprtext) { \ + if (GOOGLE_PREDICT_TRUE(v1 op v2)) return NULL; \ + else return MakeCheckOpString(v1, v2, exprtext); \ + } \ + inline std::string* name##Impl(int v1, int v2, const char* exprtext) { \ + return name##Impl(v1, v2, exprtext); \ + } + +// We use the full name Check_EQ, Check_NE, etc. in case the file including +// base/logging.h provides its own #defines for the simpler names EQ, NE, etc. +// This happens if, for example, those are used as token names in a +// yacc grammar. +DEFINE_CHECK_OP_IMPL(Check_EQ, ==) // Compilation error with CHECK_EQ(NULL, x)? +DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +DEFINE_CHECK_OP_IMPL(Check_LE, <=) +DEFINE_CHECK_OP_IMPL(Check_LT, < ) +DEFINE_CHECK_OP_IMPL(Check_GE, >=) +DEFINE_CHECK_OP_IMPL(Check_GT, > ) +#undef DEFINE_CHECK_OP_IMPL + +// Helper macro for binary operators. +// Don't use this macro directly in your code, use CHECK_EQ et al below. + +#if defined(STATIC_ANALYSIS) +// Only for static analysis tool to know that it is equivalent to assert +#define CHECK_OP_LOG(name, op, val1, val2, log) CHECK((val1) op (val2)) +#elif DCHECK_IS_ON() +// In debug mode, avoid constructing CheckOpStrings if possible, +// to reduce the overhead of CHECK statments by 2x. +// Real DCHECK-heavy tests have seen 1.5x speedups. + +// The meaning of "string" might be different between now and +// when this macro gets invoked (e.g., if someone is experimenting +// with other string implementations that get defined after this +// file is included). Save the current meaning now and use it +// in the macro. +typedef std::string _Check_string; +#define CHECK_OP_LOG(name, op, val1, val2, log) \ + while (google::_Check_string* _result = \ + google::Check##name##Impl( \ + google::GetReferenceableValue(val1), \ + google::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + log(__FILE__, __LINE__, \ + google::CheckOpString(_result)).stream() +#else +// In optimized mode, use CheckOpString to hint to compiler that +// the while condition is unlikely. +#define CHECK_OP_LOG(name, op, val1, val2, log) \ + while (google::CheckOpString _result = \ + google::Check##name##Impl( \ + google::GetReferenceableValue(val1), \ + google::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + log(__FILE__, __LINE__, _result).stream() +#endif // STATIC_ANALYSIS, DCHECK_IS_ON() + +#if GOOGLE_STRIP_LOG <= 3 +#define CHECK_OP(name, op, val1, val2) \ + CHECK_OP_LOG(name, op, val1, val2, google::LogMessageFatal) +#else +#define CHECK_OP(name, op, val1, val2) \ + CHECK_OP_LOG(name, op, val1, val2, google::NullStreamFatal) +#endif // STRIP_LOG <= 3 + +// Equality/Inequality checks - compare two values, and log a FATAL message +// including the two values when the result is not as expected. The values +// must have operator<<(ostream, ...) defined. +// +// You may append to the error message like so: +// CHECK_NE(1, 2) << ": The world must be ending!"; +// +// We are very careful to ensure that each argument is evaluated exactly +// once, and that anything which is legal to pass as a function argument is +// legal here. In particular, the arguments may be temporary expressions +// which will end up being destroyed at the end of the apparent statement, +// for example: +// CHECK_EQ(string("abc")[1], 'b'); +// +// WARNING: These don't compile correctly if one of the arguments is a pointer +// and the other is NULL. To work around this, simply static_cast NULL to the +// type of the desired pointer. + +#define CHECK_EQ(val1, val2) CHECK_OP(_EQ, ==, val1, val2) +#define CHECK_NE(val1, val2) CHECK_OP(_NE, !=, val1, val2) +#define CHECK_LE(val1, val2) CHECK_OP(_LE, <=, val1, val2) +#define CHECK_LT(val1, val2) CHECK_OP(_LT, < , val1, val2) +#define CHECK_GE(val1, val2) CHECK_OP(_GE, >=, val1, val2) +#define CHECK_GT(val1, val2) CHECK_OP(_GT, > , val1, val2) + +// Check that the input is non NULL. This very useful in constructor +// initializer lists. + +#define CHECK_NOTNULL(val) \ + google::CheckNotNull(__FILE__, __LINE__, "'" #val "' Must be non NULL", (val)) + +// Helper functions for string comparisons. +// To avoid bloat, the definitions are in logging.cc. +#define DECLARE_CHECK_STROP_IMPL(func, expected) \ + GOOGLE_GLOG_DLL_DECL std::string* Check##func##expected##Impl( \ + const char* s1, const char* s2, const char* names); +DECLARE_CHECK_STROP_IMPL(strcmp, true) +DECLARE_CHECK_STROP_IMPL(strcmp, false) +DECLARE_CHECK_STROP_IMPL(strcasecmp, true) +DECLARE_CHECK_STROP_IMPL(strcasecmp, false) +#undef DECLARE_CHECK_STROP_IMPL + +// Helper macro for string comparisons. +// Don't use this macro directly in your code, use CHECK_STREQ et al below. +#define CHECK_STROP(func, op, expected, s1, s2) \ + while (google::CheckOpString _result = \ + google::Check##func##expected##Impl((s1), (s2), \ + #s1 " " #op " " #s2)) \ + LOG(FATAL) << *_result.str_ + + +// String (char*) equality/inequality checks. +// CASE versions are case-insensitive. +// +// Note that "s1" and "s2" may be temporary strings which are destroyed +// by the compiler at the end of the current "full expression" +// (e.g. CHECK_STREQ(Foo().c_str(), Bar().c_str())). + +#define CHECK_STREQ(s1, s2) CHECK_STROP(strcmp, ==, true, s1, s2) +#define CHECK_STRNE(s1, s2) CHECK_STROP(strcmp, !=, false, s1, s2) +#define CHECK_STRCASEEQ(s1, s2) CHECK_STROP(strcasecmp, ==, true, s1, s2) +#define CHECK_STRCASENE(s1, s2) CHECK_STROP(strcasecmp, !=, false, s1, s2) + +#define CHECK_INDEX(I,A) CHECK(I < (sizeof(A)/sizeof(A[0]))) +#define CHECK_BOUND(B,A) CHECK(B <= (sizeof(A)/sizeof(A[0]))) + +#define CHECK_DOUBLE_EQ(val1, val2) \ + do { \ + CHECK_LE((val1), (val2)+0.000000000000001L); \ + CHECK_GE((val1), (val2)-0.000000000000001L); \ + } while (0) + +#define CHECK_NEAR(val1, val2, margin) \ + do { \ + CHECK_LE((val1), (val2)+(margin)); \ + CHECK_GE((val1), (val2)-(margin)); \ + } while (0) + +// perror()..googly style! +// +// PLOG() and PLOG_IF() and PCHECK() behave exactly like their LOG* and +// CHECK equivalents with the addition that they postpend a description +// of the current state of errno to their output lines. + +#define PLOG(severity) GOOGLE_PLOG(severity, 0).stream() + +#define GOOGLE_PLOG(severity, counter) \ + google::ErrnoLogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, counter, \ + &google::LogMessage::SendToLog) + +#define PLOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? (void) 0 : google::LogMessageVoidify() & PLOG(severity) + +// A CHECK() macro that postpends errno if the condition is false. E.g. +// +// if (poll(fds, nfds, timeout) == -1) { PCHECK(errno == EINTR); ... } +#define PCHECK(condition) \ + PLOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!(condition))) \ + << "Check failed: " #condition " " + +// A CHECK() macro that lets you assert the success of a function that +// returns -1 and sets errno in case of an error. E.g. +// +// CHECK_ERR(mkdir(path, 0700)); +// +// or +// +// int fd = open(filename, flags); CHECK_ERR(fd) << ": open " << filename; +#define CHECK_ERR(invocation) \ +PLOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN((invocation) == -1)) \ + << #invocation + +// Use macro expansion to create, for each use of LOG_EVERY_N(), static +// variables with the __LINE__ expansion as part of the variable name. +#define LOG_EVERY_N_VARNAME(base, line) LOG_EVERY_N_VARNAME_CONCAT(base, line) +#define LOG_EVERY_N_VARNAME_CONCAT(base, line) base ## line + +#define LOG_OCCURRENCES LOG_EVERY_N_VARNAME(occurrences_, __LINE__) +#define LOG_OCCURRENCES_MOD_N LOG_EVERY_N_VARNAME(occurrences_mod_n_, __LINE__) + +#define SOME_KIND_OF_LOG_EVERY_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (++LOG_OCCURRENCES_MOD_N > n) LOG_OCCURRENCES_MOD_N -= n; \ + if (LOG_OCCURRENCES_MOD_N == 1) \ + google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_LOG_IF_EVERY_N(severity, condition, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (condition && \ + ((LOG_OCCURRENCES_MOD_N=(LOG_OCCURRENCES_MOD_N + 1) % n) == (1 % n))) \ + google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_PLOG_EVERY_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (++LOG_OCCURRENCES_MOD_N > n) LOG_OCCURRENCES_MOD_N -= n; \ + if (LOG_OCCURRENCES_MOD_N == 1) \ + google::ErrnoLogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_LOG_FIRST_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0; \ + if (LOG_OCCURRENCES <= n) \ + ++LOG_OCCURRENCES; \ + if (LOG_OCCURRENCES <= n) \ + google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +namespace glog_internal_namespace_ { +template +struct CompileAssert { +}; +struct CrashReason; + +// Returns true if FailureSignalHandler is installed. +// Needs to be exported since it's used by the signalhandler_unittest. +GOOGLE_GLOG_DLL_DECL bool IsFailureSignalHandlerInstalled(); +} // namespace glog_internal_namespace_ + +#define LOG_EVERY_N(severity, n) \ + SOME_KIND_OF_LOG_EVERY_N(severity, (n), google::LogMessage::SendToLog) + +#define SYSLOG_EVERY_N(severity, n) \ + SOME_KIND_OF_LOG_EVERY_N(severity, (n), google::LogMessage::SendToSyslogAndLog) + +#define PLOG_EVERY_N(severity, n) \ + SOME_KIND_OF_PLOG_EVERY_N(severity, (n), google::LogMessage::SendToLog) + +#define LOG_FIRST_N(severity, n) \ + SOME_KIND_OF_LOG_FIRST_N(severity, (n), google::LogMessage::SendToLog) + +#define LOG_IF_EVERY_N(severity, condition, n) \ + SOME_KIND_OF_LOG_IF_EVERY_N(severity, (condition), (n), google::LogMessage::SendToLog) + +// We want the special COUNTER value available for LOG_EVERY_X()'ed messages +enum PRIVATE_Counter {COUNTER}; + +#ifdef GLOG_NO_ABBREVIATED_SEVERITIES +// wingdi.h defines ERROR to be 0. When we call LOG(ERROR), it gets +// substituted with 0, and it expands to COMPACT_GOOGLE_LOG_0. To allow us +// to keep using this syntax, we define this macro to do the same thing +// as COMPACT_GOOGLE_LOG_ERROR. +#define COMPACT_GOOGLE_LOG_0 COMPACT_GOOGLE_LOG_ERROR +#define SYSLOG_0 SYSLOG_ERROR +#define LOG_TO_STRING_0 LOG_TO_STRING_ERROR +// Needed for LOG_IS_ON(ERROR). +const LogSeverity GLOG_0 = GLOG_ERROR; +#else +// Users may include windows.h after logging.h without +// GLOG_NO_ABBREVIATED_SEVERITIES nor WIN32_LEAN_AND_MEAN. +// For this case, we cannot detect if ERROR is defined before users +// actually use ERROR. Let's make an undefined symbol to warn users. +# define GLOG_ERROR_MSG ERROR_macro_is_defined_Define_GLOG_NO_ABBREVIATED_SEVERITIES_before_including_logging_h_See_the_document_for_detail +# define COMPACT_GOOGLE_LOG_0 GLOG_ERROR_MSG +# define SYSLOG_0 GLOG_ERROR_MSG +# define LOG_TO_STRING_0 GLOG_ERROR_MSG +# define GLOG_0 GLOG_ERROR_MSG +#endif + +// Plus some debug-logging macros that get compiled to nothing for production + +#if DCHECK_IS_ON() + +#define DLOG(severity) LOG(severity) +#define DVLOG(verboselevel) VLOG(verboselevel) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#define DLOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n) +#define DLOG_IF_EVERY_N(severity, condition, n) \ + LOG_IF_EVERY_N(severity, condition, n) +#define DLOG_ASSERT(condition) LOG_ASSERT(condition) + +// debug-only checking. executed if DCHECK_IS_ON(). +#define DCHECK(condition) CHECK(condition) +#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) +#define DCHECK_NOTNULL(val) CHECK_NOTNULL(val) +#define DCHECK_STREQ(str1, str2) CHECK_STREQ(str1, str2) +#define DCHECK_STRCASEEQ(str1, str2) CHECK_STRCASEEQ(str1, str2) +#define DCHECK_STRNE(str1, str2) CHECK_STRNE(str1, str2) +#define DCHECK_STRCASENE(str1, str2) CHECK_STRCASENE(str1, str2) + +#else // !DCHECK_IS_ON() + +#define DLOG(severity) \ + static_cast(0), \ + true ? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DVLOG(verboselevel) \ + static_cast(0), \ + (true || !VLOG_IS_ON(verboselevel)) ? \ + (void) 0 : google::LogMessageVoidify() & LOG(INFO) + +#define DLOG_IF(severity, condition) \ + static_cast(0), \ + (true || !(condition)) ? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_EVERY_N(severity, n) \ + static_cast(0), \ + true ? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_IF_EVERY_N(severity, condition, n) \ + static_cast(0), \ + (true || !(condition))? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_ASSERT(condition) \ + static_cast(0), \ + true ? (void) 0 : LOG_ASSERT(condition) + +// MSVC warning C4127: conditional expression is constant +#define DCHECK(condition) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK(condition) + +#define DCHECK_EQ(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_EQ(val1, val2) + +#define DCHECK_NE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_NE(val1, val2) + +#define DCHECK_LE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_LE(val1, val2) + +#define DCHECK_LT(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_LT(val1, val2) + +#define DCHECK_GE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_GE(val1, val2) + +#define DCHECK_GT(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_GT(val1, val2) + +// You may see warnings in release mode if you don't use the return +// value of DCHECK_NOTNULL. Please just use DCHECK for such cases. +#define DCHECK_NOTNULL(val) (val) + +#define DCHECK_STREQ(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STREQ(str1, str2) + +#define DCHECK_STRCASEEQ(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRCASEEQ(str1, str2) + +#define DCHECK_STRNE(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRNE(str1, str2) + +#define DCHECK_STRCASENE(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRCASENE(str1, str2) + +#endif // DCHECK_IS_ON() + +// Log only in verbose mode. + +#define VLOG(verboselevel) LOG_IF(INFO, VLOG_IS_ON(verboselevel)) + +#define VLOG_IF(verboselevel, condition) \ + LOG_IF(INFO, (condition) && VLOG_IS_ON(verboselevel)) + +#define VLOG_EVERY_N(verboselevel, n) \ + LOG_IF_EVERY_N(INFO, VLOG_IS_ON(verboselevel), n) + +#define VLOG_IF_EVERY_N(verboselevel, condition, n) \ + LOG_IF_EVERY_N(INFO, (condition) && VLOG_IS_ON(verboselevel), n) + +namespace base_logging { + +// LogMessage::LogStream is a std::ostream backed by this streambuf. +// This class ignores overflow and leaves two bytes at the end of the +// buffer to allow for a '\n' and '\0'. +class GOOGLE_GLOG_DLL_DECL LogStreamBuf : public std::streambuf { + public: + // REQUIREMENTS: "len" must be >= 2 to account for the '\n' and '\0'. + LogStreamBuf(char *buf, int len) { + setp(buf, buf + len - 2); + } + + // This effectively ignores overflow. + virtual int_type overflow(int_type ch) { + return ch; + } + + // Legacy public ostrstream method. + size_t pcount() const { return pptr() - pbase(); } + char* pbase() const { return std::streambuf::pbase(); } +}; + +} // namespace base_logging + +// +// This class more or less represents a particular log message. You +// create an instance of LogMessage and then stream stuff to it. +// When you finish streaming to it, ~LogMessage is called and the +// full message gets streamed to the appropriate destination. +// +// You shouldn't actually use LogMessage's constructor to log things, +// though. You should use the LOG() macro (and variants thereof) +// above. +class GOOGLE_GLOG_DLL_DECL LogMessage { +public: + enum { + // Passing kNoLogPrefix for the line number disables the + // log-message prefix. Useful for using the LogMessage + // infrastructure as a printing utility. See also the --log_prefix + // flag for controlling the log-message prefix on an + // application-wide basis. + kNoLogPrefix = -1 + }; + + // LogStream inherit from non-DLL-exported class (std::ostrstream) + // and VC++ produces a warning for this situation. + // However, MSDN says "C4275 can be ignored in Microsoft Visual C++ + // 2005 if you are deriving from a type in the Standard C++ Library" + // http://msdn.microsoft.com/en-us/library/3tdb471s(VS.80).aspx + // Let's just ignore the warning. +GLOG_MSVC_PUSH_DISABLE_WARNING(4275) + class GOOGLE_GLOG_DLL_DECL LogStream : public std::ostream { +GLOG_MSVC_POP_WARNING() + public: + LogStream(char *buf, int len, int ctr) + : std::ostream(NULL), + streambuf_(buf, len), + ctr_(ctr), + self_(this) { + rdbuf(&streambuf_); + } + + int ctr() const { return ctr_; } + void set_ctr(int ctr) { ctr_ = ctr; } + LogStream* self() const { return self_; } + + // Legacy std::streambuf methods. + size_t pcount() const { return streambuf_.pcount(); } + char* pbase() const { return streambuf_.pbase(); } + char* str() const { return pbase(); } + + private: + LogStream(const LogStream&); + LogStream& operator=(const LogStream&); + base_logging::LogStreamBuf streambuf_; + int ctr_; // Counter hack (for the LOG_EVERY_X() macro) + LogStream *self_; // Consistency check hack + }; + +public: + // icc 8 requires this typedef to avoid an internal compiler error. + typedef void (LogMessage::*SendMethod)(); + + LogMessage(const char* file, int line, LogSeverity severity, int ctr, + SendMethod send_method); + + // Two special constructors that generate reduced amounts of code at + // LOG call sites for common cases. + + // Used for LOG(INFO): Implied are: + // severity = INFO, ctr = 0, send_method = &LogMessage::SendToLog. + // + // Using this constructor instead of the more complex constructor above + // saves 19 bytes per call site. + LogMessage(const char* file, int line); + + // Used for LOG(severity) where severity != INFO. Implied + // are: ctr = 0, send_method = &LogMessage::SendToLog + // + // Using this constructor instead of the more complex constructor above + // saves 17 bytes per call site. + LogMessage(const char* file, int line, LogSeverity severity); + + // Constructor to log this message to a specified sink (if not NULL). + // Implied are: ctr = 0, send_method = &LogMessage::SendToSinkAndLog if + // also_send_to_log is true, send_method = &LogMessage::SendToSink otherwise. + LogMessage(const char* file, int line, LogSeverity severity, LogSink* sink, + bool also_send_to_log); + + // Constructor where we also give a vector pointer + // for storing the messages (if the pointer is not NULL). + // Implied are: ctr = 0, send_method = &LogMessage::SaveOrSendToLog. + LogMessage(const char* file, int line, LogSeverity severity, + std::vector* outvec); + + // Constructor where we also give a string pointer for storing the + // message (if the pointer is not NULL). Implied are: ctr = 0, + // send_method = &LogMessage::WriteToStringAndLog. + LogMessage(const char* file, int line, LogSeverity severity, + std::string* message); + + // A special constructor used for check failures + LogMessage(const char* file, int line, const CheckOpString& result); + + ~LogMessage(); + + // Flush a buffered message to the sink set in the constructor. Always + // called by the destructor, it may also be called from elsewhere if + // needed. Only the first call is actioned; any later ones are ignored. + void Flush(); + + // An arbitrary limit on the length of a single log message. This + // is so that streaming can be done more efficiently. + static const size_t kMaxLogMessageLen; + + // Theses should not be called directly outside of logging.*, + // only passed as SendMethod arguments to other LogMessage methods: + void SendToLog(); // Actually dispatch to the logs + void SendToSyslogAndLog(); // Actually dispatch to syslog and the logs + + // Call abort() or similar to perform LOG(FATAL) crash. + static void __attribute__((noreturn)) Fail(); + + std::ostream& stream(); + + int preserved_errno() const; + + // Must be called without the log_mutex held. (L < log_mutex) + static int64 num_messages(int severity); + + struct LogMessageData; + +private: + // Fully internal SendMethod cases: + void SendToSinkAndLog(); // Send to sink if provided and dispatch to the logs + void SendToSink(); // Send to sink if provided, do nothing otherwise. + + // Write to string if provided and dispatch to the logs. + void WriteToStringAndLog(); + + void SaveOrSendToLog(); // Save to stringvec if provided, else to logs + + void Init(const char* file, int line, LogSeverity severity, + void (LogMessage::*send_method)()); + + // Used to fill in crash information during LOG(FATAL) failures. + void RecordCrashReason(glog_internal_namespace_::CrashReason* reason); + + // Counts of messages sent at each priority: + static int64 num_messages_[NUM_SEVERITIES]; // under log_mutex + + // We keep the data in a separate struct so that each instance of + // LogMessage uses less stack space. + LogMessageData* allocated_; + LogMessageData* data_; + + friend class LogDestination; + + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +// This class happens to be thread-hostile because all instances share +// a single data buffer, but since it can only be created just before +// the process dies, we don't worry so much. +class GOOGLE_GLOG_DLL_DECL LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line); + LogMessageFatal(const char* file, int line, const CheckOpString& result); + __attribute__((noreturn)) ~LogMessageFatal(); +}; + +// A non-macro interface to the log facility; (useful +// when the logging level is not a compile-time constant). +inline void LogAtLevel(int const severity, std::string const &msg) { + LogMessage(__FILE__, __LINE__, severity).stream() << msg; +} + +// A macro alternative of LogAtLevel. New code may want to use this +// version since there are two advantages: 1. this version outputs the +// file name and the line number where this macro is put like other +// LOG macros, 2. this macro can be used as C++ stream. +#define LOG_AT_LEVEL(severity) google::LogMessage(__FILE__, __LINE__, severity).stream() + +// Check if it's compiled in C++11 mode. +// +// GXX_EXPERIMENTAL_CXX0X is defined by gcc and clang up to at least +// gcc-4.7 and clang-3.1 (2011-12-13). __cplusplus was defined to 1 +// in gcc before 4.7 (Crosstool 16) and clang before 3.1, but is +// defined according to the language version in effect thereafter. +// Microsoft Visual Studio 14 (2015) sets __cplusplus==199711 despite +// reasonably good C++11 support, so we set LANG_CXX for it and +// newer versions (_MSC_VER >= 1900). +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900)) +// Helper for CHECK_NOTNULL(). +// +// In C++11, all cases can be handled by a single function. Since the value +// category of the argument is preserved (also for rvalue references), +// member initializer lists like the one below will compile correctly: +// +// Foo() +// : x_(CHECK_NOTNULL(MethodReturningUniquePtr())) {} +template +T CheckNotNull(const char* file, int line, const char* names, T&& t) { + if (t == nullptr) { + LogMessageFatal(file, line, new std::string(names)); + } + return std::forward(t); +} + +#else + +// A small helper for CHECK_NOTNULL(). +template +T* CheckNotNull(const char *file, int line, const char *names, T* t) { + if (t == NULL) { + LogMessageFatal(file, line, new std::string(names)); + } + return t; +} +#endif + +// Allow folks to put a counter in the LOG_EVERY_X()'ed messages. This +// only works if ostream is a LogStream. If the ostream is not a +// LogStream you'll get an assert saying as much at runtime. +GOOGLE_GLOG_DLL_DECL std::ostream& operator<<(std::ostream &os, + const PRIVATE_Counter&); + + +// Derived class for PLOG*() above. +class GOOGLE_GLOG_DLL_DECL ErrnoLogMessage : public LogMessage { + public: + + ErrnoLogMessage(const char* file, int line, LogSeverity severity, int ctr, + void (LogMessage::*send_method)()); + + // Postpends ": strerror(errno) [errno]". + ~ErrnoLogMessage(); + + private: + ErrnoLogMessage(const ErrnoLogMessage&); + void operator=(const ErrnoLogMessage&); +}; + + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". + +class GOOGLE_GLOG_DLL_DECL LogMessageVoidify { + public: + LogMessageVoidify() { } + // This has to be an operator with a precedence lower than << but + // higher than ?: + void operator&(std::ostream&) { } +}; + + +// Flushes all log files that contains messages that are at least of +// the specified severity level. Thread-safe. +GOOGLE_GLOG_DLL_DECL void FlushLogFiles(LogSeverity min_severity); + +// Flushes all log files that contains messages that are at least of +// the specified severity level. Thread-hostile because it ignores +// locking -- used for catastrophic failures. +GOOGLE_GLOG_DLL_DECL void FlushLogFilesUnsafe(LogSeverity min_severity); + +// +// Set the destination to which a particular severity level of log +// messages is sent. If base_filename is "", it means "don't log this +// severity". Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogDestination(LogSeverity severity, + const char* base_filename); + +// +// Set the basename of the symlink to the latest log file at a given +// severity. If symlink_basename is empty, do not make a symlink. If +// you don't call this function, the symlink basename is the +// invocation name of the program. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogSymlink(LogSeverity severity, + const char* symlink_basename); + +// +// Used to send logs to some other kind of destination +// Users should subclass LogSink and override send to do whatever they want. +// Implementations must be thread-safe because a shared instance will +// be called from whichever thread ran the LOG(XXX) line. +class GOOGLE_GLOG_DLL_DECL LogSink { + public: + virtual ~LogSink(); + + // Sink's logging logic (message_len is such as to exclude '\n' at the end). + // This method can't use LOG() or CHECK() as logging system mutex(s) are held + // during this call. + virtual void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, + const struct ::tm* tm_time, + const char* message, size_t message_len) = 0; + + // Redefine this to implement waiting for + // the sink's logging logic to complete. + // It will be called after each send() returns, + // but before that LogMessage exits or crashes. + // By default this function does nothing. + // Using this function one can implement complex logic for send() + // that itself involves logging; and do all this w/o causing deadlocks and + // inconsistent rearrangement of log messages. + // E.g. if a LogSink has thread-specific actions, the send() method + // can simply add the message to a queue and wake up another thread that + // handles real logging while itself making some LOG() calls; + // WaitTillSent() can be implemented to wait for that logic to complete. + // See our unittest for an example. + virtual void WaitTillSent(); + + // Returns the normal text output of the log message. + // Can be useful to implement send(). + static std::string ToString(LogSeverity severity, const char* file, int line, + const struct ::tm* tm_time, + const char* message, size_t message_len); +}; + +// Add or remove a LogSink as a consumer of logging data. Thread-safe. +GOOGLE_GLOG_DLL_DECL void AddLogSink(LogSink *destination); +GOOGLE_GLOG_DLL_DECL void RemoveLogSink(LogSink *destination); + +// +// Specify an "extension" added to the filename specified via +// SetLogDestination. This applies to all severity levels. It's +// often used to append the port we're listening on to the logfile +// name. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogFilenameExtension( + const char* filename_extension); + +// +// Make it so that all log messages of at least a particular severity +// are logged to stderr (in addition to logging to the usual log +// file(s)). Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetStderrLogging(LogSeverity min_severity); + +// +// Make it so that all log messages go only to stderr. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void LogToStderr(); + +// +// Make it so that all log messages of at least a particular severity are +// logged via email to a list of addresses (in addition to logging to the +// usual log file(s)). The list of addresses is just a string containing +// the email addresses to send to (separated by spaces, say). Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetEmailLogging(LogSeverity min_severity, + const char* addresses); + +// A simple function that sends email. dest is a commma-separated +// list of addressess. Thread-safe. +GOOGLE_GLOG_DLL_DECL bool SendEmail(const char *dest, + const char *subject, const char *body); + +GOOGLE_GLOG_DLL_DECL const std::vector& GetLoggingDirectories(); + +// For tests only: Clear the internal [cached] list of logging directories to +// force a refresh the next time GetLoggingDirectories is called. +// Thread-hostile. +void TestOnly_ClearLoggingDirectoriesList(); + +// Returns a set of existing temporary directories, which will be a +// subset of the directories returned by GetLogginDirectories(). +// Thread-safe. +GOOGLE_GLOG_DLL_DECL void GetExistingTempDirectories( + std::vector* list); + +// Print any fatal message again -- useful to call from signal handler +// so that the last thing in the output is the fatal message. +// Thread-hostile, but a race is unlikely. +GOOGLE_GLOG_DLL_DECL void ReprintFatalMessage(); + +// Truncate a log file that may be the append-only output of multiple +// processes and hence can't simply be renamed/reopened (typically a +// stdout/stderr). If the file "path" is > "limit" bytes, copy the +// last "keep" bytes to offset 0 and truncate the rest. Since we could +// be racing with other writers, this approach has the potential to +// lose very small amounts of data. For security, only follow symlinks +// if the path is /proc/self/fd/* +GOOGLE_GLOG_DLL_DECL void TruncateLogFile(const char *path, + int64 limit, int64 keep); + +// Truncate stdout and stderr if they are over the value specified by +// --max_log_size; keep the final 1MB. This function has the same +// race condition as TruncateLogFile. +GOOGLE_GLOG_DLL_DECL void TruncateStdoutStderr(); + +// Return the string representation of the provided LogSeverity level. +// Thread-safe. +GOOGLE_GLOG_DLL_DECL const char* GetLogSeverityName(LogSeverity severity); + +// --------------------------------------------------------------------- +// Implementation details that are not useful to most clients +// --------------------------------------------------------------------- + +// A Logger is the interface used by logging modules to emit entries +// to a log. A typical implementation will dump formatted data to a +// sequence of files. We also provide interfaces that will forward +// the data to another thread so that the invoker never blocks. +// Implementations should be thread-safe since the logging system +// will write to them from multiple threads. + +namespace base { + +class GOOGLE_GLOG_DLL_DECL Logger { + public: + virtual ~Logger(); + + // Writes "message[0,message_len-1]" corresponding to an event that + // occurred at "timestamp". If "force_flush" is true, the log file + // is flushed immediately. + // + // The input message has already been formatted as deemed + // appropriate by the higher level logging facility. For example, + // textual log messages already contain timestamps, and the + // file:linenumber header. + virtual void Write(bool force_flush, + time_t timestamp, + const char* message, + int message_len) = 0; + + // Flush any buffered messages + virtual void Flush() = 0; + + // Get the current LOG file size. + // The returned value is approximate since some + // logged data may not have been flushed to disk yet. + virtual uint32 LogSize() = 0; +}; + +// Get the logger for the specified severity level. The logger +// remains the property of the logging module and should not be +// deleted by the caller. Thread-safe. +extern GOOGLE_GLOG_DLL_DECL Logger* GetLogger(LogSeverity level); + +// Set the logger for the specified severity level. The logger +// becomes the property of the logging module and should not +// be deleted by the caller. Thread-safe. +extern GOOGLE_GLOG_DLL_DECL void SetLogger(LogSeverity level, Logger* logger); + +} + +// glibc has traditionally implemented two incompatible versions of +// strerror_r(). There is a poorly defined convention for picking the +// version that we want, but it is not clear whether it even works with +// all versions of glibc. +// So, instead, we provide this wrapper that automatically detects the +// version that is in use, and then implements POSIX semantics. +// N.B. In addition to what POSIX says, we also guarantee that "buf" will +// be set to an empty string, if this function failed. This means, in most +// cases, you do not need to check the error code and you can directly +// use the value of "buf". It will never have an undefined value. +// DEPRECATED: Use StrError(int) instead. +GOOGLE_GLOG_DLL_DECL int posix_strerror_r(int err, char *buf, size_t len); + +// A thread-safe replacement for strerror(). Returns a string describing the +// given POSIX error code. +GOOGLE_GLOG_DLL_DECL std::string StrError(int err); + +// A class for which we define operator<<, which does nothing. +class GOOGLE_GLOG_DLL_DECL NullStream : public LogMessage::LogStream { + public: + // Initialize the LogStream so the messages can be written somewhere + // (they'll never be actually displayed). This will be needed if a + // NullStream& is implicitly converted to LogStream&, in which case + // the overloaded NullStream::operator<< will not be invoked. + NullStream() : LogMessage::LogStream(message_buffer_, 1, 0) { } + NullStream(const char* /*file*/, int /*line*/, + const CheckOpString& /*result*/) : + LogMessage::LogStream(message_buffer_, 1, 0) { } + NullStream &stream() { return *this; } + private: + // A very short buffer for messages (which we discard anyway). This + // will be needed if NullStream& converted to LogStream& (e.g. as a + // result of a conditional expression). + char message_buffer_[2]; +}; + +// Do nothing. This operator is inline, allowing the message to be +// compiled away. The message will not be compiled away if we do +// something like (flag ? LOG(INFO) : LOG(ERROR)) << message; when +// SKIP_LOG=WARNING. In those cases, NullStream will be implicitly +// converted to LogStream and the message will be computed and then +// quietly discarded. +template +inline NullStream& operator<<(NullStream &str, const T &) { return str; } + +// Similar to NullStream, but aborts the program (without stack +// trace), like LogMessageFatal. +class GOOGLE_GLOG_DLL_DECL NullStreamFatal : public NullStream { + public: + NullStreamFatal() { } + NullStreamFatal(const char* file, int line, const CheckOpString& result) : + NullStream(file, line, result) { } + __attribute__((noreturn)) ~NullStreamFatal() throw () { _exit(1); } +}; + +// Install a signal handler that will dump signal information and a stack +// trace when the program crashes on certain signals. We'll install the +// signal handler for the following signals. +// +// SIGSEGV, SIGILL, SIGFPE, SIGABRT, SIGBUS, and SIGTERM. +// +// By default, the signal handler will write the failure dump to the +// standard error. You can customize the destination by installing your +// own writer function by InstallFailureWriter() below. +// +// Note on threading: +// +// The function should be called before threads are created, if you want +// to use the failure signal handler for all threads. The stack trace +// will be shown only for the thread that receives the signal. In other +// words, stack traces of other threads won't be shown. +GOOGLE_GLOG_DLL_DECL void InstallFailureSignalHandler(); + +// Installs a function that is used for writing the failure dump. "data" +// is the pointer to the beginning of a message to be written, and "size" +// is the size of the message. You should not expect the data is +// terminated with '\0'. +GOOGLE_GLOG_DLL_DECL void InstallFailureWriter( + void (*writer)(const char* data, int size)); + +} + +#endif // _LOGGING_H_ diff --git a/projects/llm_framework/include/glog/raw_logging.h b/projects/llm_framework/include/glog/raw_logging.h new file mode 100644 index 00000000..cf3f27d9 --- /dev/null +++ b/projects/llm_framework/include/glog/raw_logging.h @@ -0,0 +1,180 @@ +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: Maxim Lifantsev +// +// Thread-safe logging routines that do not allocate any memory or +// acquire any locks, and can therefore be used by low-level memory +// allocation and synchronization code. + +#ifndef BASE_RAW_LOGGING_H_ +#define BASE_RAW_LOGGING_H_ + +#include + +namespace google { + +#include "glog/log_severity.h" +#include "glog/vlog_is_on.h" + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// This is similar to LOG(severity) << format... and VLOG(level) << format.., +// but +// * it is to be used ONLY by low-level modules that can't use normal LOG() +// * it is desiged to be a low-level logger that does not allocate any +// memory and does not need any locks, hence: +// * it logs straight and ONLY to STDERR w/o buffering +// * it uses an explicit format and arguments list +// * it will silently chop off really long message strings +// Usage example: +// RAW_LOG(ERROR, "Failed foo with %i: %s", status, error); +// RAW_VLOG(3, "status is %i", status); +// These will print an almost standard log lines like this to stderr only: +// E0821 211317 file.cc:123] RAW: Failed foo with 22: bad_file +// I0821 211317 file.cc:142] RAW: status is 20 +#define RAW_LOG(severity, ...) \ + do { \ + switch (google::GLOG_ ## severity) { \ + case 0: \ + RAW_LOG_INFO(__VA_ARGS__); \ + break; \ + case 1: \ + RAW_LOG_WARNING(__VA_ARGS__); \ + break; \ + case 2: \ + RAW_LOG_ERROR(__VA_ARGS__); \ + break; \ + case 3: \ + RAW_LOG_FATAL(__VA_ARGS__); \ + break; \ + default: \ + break; \ + } \ + } while (0) + +// The following STRIP_LOG testing is performed in the header file so that it's +// possible to completely compile out the logging code and the log messages. +#if STRIP_LOG == 0 +#define RAW_VLOG(verboselevel, ...) \ + do { \ + if (VLOG_IS_ON(verboselevel)) { \ + RAW_LOG_INFO(__VA_ARGS__); \ + } \ + } while (0) +#else +#define RAW_VLOG(verboselevel, ...) RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG == 0 + +#if STRIP_LOG == 0 +#define RAW_LOG_INFO(...) google::RawLog__(google::GLOG_INFO, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_INFO(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG == 0 + +#if STRIP_LOG <= 1 +#define RAW_LOG_WARNING(...) google::RawLog__(google::GLOG_WARNING, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_WARNING(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG <= 1 + +#if STRIP_LOG <= 2 +#define RAW_LOG_ERROR(...) google::RawLog__(google::GLOG_ERROR, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_ERROR(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG <= 2 + +#if STRIP_LOG <= 3 +#define RAW_LOG_FATAL(...) google::RawLog__(google::GLOG_FATAL, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_FATAL(...) \ + do { \ + google::RawLogStub__(0, __VA_ARGS__); \ + exit(1); \ + } while (0) +#endif // STRIP_LOG <= 3 + +// Similar to CHECK(condition) << message, +// but for low-level modules: we use only RAW_LOG that does not allocate memory. +// We do not want to provide args list here to encourage this usage: +// if (!cond) RAW_LOG(FATAL, "foo ...", hard_to_compute_args); +// so that the args are not computed when not needed. +#define RAW_CHECK(condition, message) \ + do { \ + if (!(condition)) { \ + RAW_LOG(FATAL, "Check %s failed: %s", #condition, message); \ + } \ + } while (0) + +// Debug versions of RAW_LOG and RAW_CHECK +#ifndef NDEBUG + +#define RAW_DLOG(severity, ...) RAW_LOG(severity, __VA_ARGS__) +#define RAW_DCHECK(condition, message) RAW_CHECK(condition, message) + +#else // NDEBUG + +#define RAW_DLOG(severity, ...) \ + while (false) \ + RAW_LOG(severity, __VA_ARGS__) +#define RAW_DCHECK(condition, message) \ + while (false) \ + RAW_CHECK(condition, message) + +#endif // NDEBUG + +// Stub log function used to work around for unused variable warnings when +// building with STRIP_LOG > 0. +static inline void RawLogStub__(int /* ignored */, ...) { +} + +// Helper function to implement RAW_LOG and RAW_VLOG +// Logs format... at "severity" level, reporting it +// as called from file:line. +// This does not allocate memory or acquire locks. +GOOGLE_GLOG_DLL_DECL void RawLog__(LogSeverity severity, + const char* file, + int line, + const char* format, ...) + ; + +} + +#endif // BASE_RAW_LOGGING_H_ diff --git a/projects/llm_framework/include/glog/stl_logging.h b/projects/llm_framework/include/glog/stl_logging.h new file mode 100644 index 00000000..40a15aa4 --- /dev/null +++ b/projects/llm_framework/include/glog/stl_logging.h @@ -0,0 +1,220 @@ +// Copyright (c) 2003, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Stream output operators for STL containers; to be used for logging *only*. +// Inclusion of this file lets you do: +// +// list x; +// LOG(INFO) << "data: " << x; +// vector v1, v2; +// CHECK_EQ(v1, v2); +// +// If you want to use this header file with hash maps or slist, you +// need to define macros before including this file: +// +// - GLOG_STL_LOGGING_FOR_UNORDERED - and +// - GLOG_STL_LOGGING_FOR_TR1_UNORDERED - +// - GLOG_STL_LOGGING_FOR_EXT_HASH - +// - GLOG_STL_LOGGING_FOR_EXT_SLIST - +// + +#ifndef UTIL_GTL_STL_LOGGING_INL_H_ +#define UTIL_GTL_STL_LOGGING_INL_H_ + +#if !1 +# error We do not support stl_logging for this compiler +#endif + +#include +#include +#include +#include +#include +#include +#include + +#ifdef GLOG_STL_LOGGING_FOR_UNORDERED +# include +# include +#endif + +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +# include +# include +#endif + +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +# include +# include +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_SLIST +# include +#endif + +// Forward declare these two, and define them after all the container streams +// operators so that we can recurse from pair -> container -> container -> pair +// properly. +template +std::ostream& operator<<(std::ostream& out, const std::pair& p); + +namespace google { + +template +void PrintSequence(std::ostream& out, Iter begin, Iter end); + +} + +#define OUTPUT_TWO_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_TWO_ARG_CONTAINER(std::vector) +OUTPUT_TWO_ARG_CONTAINER(std::deque) +OUTPUT_TWO_ARG_CONTAINER(std::list) +#ifdef GLOG_STL_LOGGING_FOR_EXT_SLIST +OUTPUT_TWO_ARG_CONTAINER(__gnu_cxx::slist) +#endif + +#undef OUTPUT_TWO_ARG_CONTAINER + +#define OUTPUT_THREE_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_THREE_ARG_CONTAINER(std::set) +OUTPUT_THREE_ARG_CONTAINER(std::multiset) + +#undef OUTPUT_THREE_ARG_CONTAINER + +#define OUTPUT_FOUR_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_FOUR_ARG_CONTAINER(std::map) +OUTPUT_FOUR_ARG_CONTAINER(std::multimap) +#ifdef GLOG_STL_LOGGING_FOR_UNORDERED +OUTPUT_FOUR_ARG_CONTAINER(std::unordered_set) +OUTPUT_FOUR_ARG_CONTAINER(std::unordered_multiset) +#endif +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +OUTPUT_FOUR_ARG_CONTAINER(std::tr1::unordered_set) +OUTPUT_FOUR_ARG_CONTAINER(std::tr1::unordered_multiset) +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +OUTPUT_FOUR_ARG_CONTAINER(__gnu_cxx::hash_set) +OUTPUT_FOUR_ARG_CONTAINER(__gnu_cxx::hash_multiset) +#endif + +#undef OUTPUT_FOUR_ARG_CONTAINER + +#define OUTPUT_FIVE_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +#ifdef GLOG_STL_LOGGING_FOR_UNORDERED +OUTPUT_FIVE_ARG_CONTAINER(std::unordered_map) +OUTPUT_FIVE_ARG_CONTAINER(std::unordered_multimap) +#endif +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +OUTPUT_FIVE_ARG_CONTAINER(std::tr1::unordered_map) +OUTPUT_FIVE_ARG_CONTAINER(std::tr1::unordered_multimap) +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +OUTPUT_FIVE_ARG_CONTAINER(__gnu_cxx::hash_map) +OUTPUT_FIVE_ARG_CONTAINER(__gnu_cxx::hash_multimap) +#endif + +#undef OUTPUT_FIVE_ARG_CONTAINER + +template +inline std::ostream& operator<<(std::ostream& out, + const std::pair& p) { + out << '(' << p.first << ", " << p.second << ')'; + return out; +} + +namespace google { + +template +inline void PrintSequence(std::ostream& out, Iter begin, Iter end) { + // Output at most 100 elements -- appropriate if used for logging. + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } +} + +} + +// Note that this is technically undefined behavior! We are adding things into +// the std namespace for a reason though -- we are providing new operations on +// types which are themselves defined with this namespace. Without this, these +// operator overloads cannot be found via ADL. If these definitions are not +// found via ADL, they must be #included before they're used, which requires +// this header to be included before apparently independent other headers. +// +// For example, base/logging.h defines various template functions to implement +// CHECK_EQ(x, y) and stream x and y into the log in the event the check fails. +// It does so via the function template MakeCheckOpValueString: +// template +// void MakeCheckOpValueString(strstream* ss, const T& v) { +// (*ss) << v; +// } +// Because 'glog/logging.h' is included before 'glog/stl_logging.h', +// subsequent CHECK_EQ(v1, v2) for vector<...> typed variable v1 and v2 can only +// find these operator definitions via ADL. +// +// Even this solution has problems -- it may pull unintended operators into the +// namespace as well, allowing them to also be found via ADL, and creating code +// that only works with a particular order of includes. Long term, we need to +// move all of the *definitions* into namespace std, bet we need to ensure no +// one references them first. This lets us take that step. We cannot define them +// in both because that would create ambiguous overloads when both are found. +namespace std { using ::operator<<; } + +#endif // UTIL_GTL_STL_LOGGING_INL_H_ diff --git a/projects/llm_framework/include/glog/vlog_is_on.h b/projects/llm_framework/include/glog/vlog_is_on.h new file mode 100644 index 00000000..02b0b867 --- /dev/null +++ b/projects/llm_framework/include/glog/vlog_is_on.h @@ -0,0 +1,129 @@ +// Copyright (c) 1999, 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: Ray Sidney and many others +// +// Defines the VLOG_IS_ON macro that controls the variable-verbosity +// conditional logging. +// +// It's used by VLOG and VLOG_IF in logging.h +// and by RAW_VLOG in raw_logging.h to trigger the logging. +// +// It can also be used directly e.g. like this: +// if (VLOG_IS_ON(2)) { +// // do some logging preparation and logging +// // that can't be accomplished e.g. via just VLOG(2) << ...; +// } +// +// The truth value that VLOG_IS_ON(level) returns is determined by +// the three verbosity level flags: +// --v= Gives the default maximal active V-logging level; +// 0 is the default. +// Normally positive values are used for V-logging levels. +// --vmodule= Gives the per-module maximal V-logging levels to override +// the value given by --v. +// E.g. "my_module=2,foo*=3" would change the logging level +// for all code in source files "my_module.*" and "foo*.*" +// ("-inl" suffixes are also disregarded for this matching). +// +// SetVLOGLevel helper function is provided to do limited dynamic control over +// V-logging by overriding the per-module settings given via --vmodule flag. +// +// CAVEAT: --vmodule functionality is not available in non gcc compilers. +// + +#ifndef BASE_VLOG_IS_ON_H_ +#define BASE_VLOG_IS_ON_H_ + +#include "glog/log_severity.h" + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +#if defined(__GNUC__) +// We emit an anonymous static int* variable at every VLOG_IS_ON(n) site. +// (Normally) the first time every VLOG_IS_ON(n) site is hit, +// we determine what variable will dynamically control logging at this site: +// it's either FLAGS_v or an appropriate internal variable +// matching the current source file that represents results of +// parsing of --vmodule flag and/or SetVLOGLevel calls. +#define VLOG_IS_ON(verboselevel) \ + __extension__ \ + ({ static google::int32* vlocal__ = &google::kLogSiteUninitialized; \ + google::int32 verbose_level__ = (verboselevel); \ + (*vlocal__ >= verbose_level__) && \ + ((vlocal__ != &google::kLogSiteUninitialized) || \ + (google::InitVLOG3__(&vlocal__, &FLAGS_v, \ + __FILE__, verbose_level__))); }) +#else +// GNU extensions not available, so we do not support --vmodule. +// Dynamic value of FLAGS_v always controls the logging level. +#define VLOG_IS_ON(verboselevel) (FLAGS_v >= (verboselevel)) +#endif + +// Set VLOG(_IS_ON) level for module_pattern to log_level. +// This lets us dynamically control what is normally set by the --vmodule flag. +// Returns the level that previously applied to module_pattern. +// NOTE: To change the log level for VLOG(_IS_ON) sites +// that have already executed after/during InitGoogleLogging, +// one needs to supply the exact --vmodule pattern that applied to them. +// (If no --vmodule pattern applied to them +// the value of FLAGS_v will continue to control them.) +extern GOOGLE_GLOG_DLL_DECL int SetVLOGLevel(const char* module_pattern, + int log_level); + +// Various declarations needed for VLOG_IS_ON above: ========================= + +// Special value used to indicate that a VLOG_IS_ON site has not been +// initialized. We make this a large value, so the common-case check +// of "*vlocal__ >= verbose_level__" in VLOG_IS_ON definition +// passes in such cases and InitVLOG3__ is then triggered. +extern google::int32 kLogSiteUninitialized; + +// Helper routine which determines the logging info for a particalur VLOG site. +// site_flag is the address of the site-local pointer to the controlling +// verbosity level +// site_default is the default to use for *site_flag +// fname is the current source file name +// verbose_level is the argument to VLOG_IS_ON +// We will return the return value for VLOG_IS_ON +// and if possible set *site_flag appropriately. +extern GOOGLE_GLOG_DLL_DECL bool InitVLOG3__( + google::int32** site_flag, + google::int32* site_default, + const char* fname, + google::int32 verbose_level); + +#endif // BASE_VLOG_IS_ON_H_ diff --git a/projects/llm_framework/main_melotts/SConstruct b/projects/llm_framework/main_melotts/SConstruct index 6663ca30..87886e09 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -25,9 +25,12 @@ REQUIREMENTS += ['samplerate'] INCLUDE += [ADir('../include')] INCLUDE += [ADir('src/runner'), ADir('../include/onnxruntime/core/session')] - +LINK_SEARCH_PATH += [ADir('../static_lib/wetext')] LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a'] +LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a','-l:libglog.so','-l:libfst.so'] + + +LDFLAGS += [] STATIC_FILES += Glob('mode_*.json') diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-default.json b/projects/llm_framework/main_melotts/mode_melotts-en-default.json index 18945145..2bf39dd6 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-en-default.json +++ b/projects/llm_framework/main_melotts/mode_melotts-en-default.json @@ -21,6 +21,8 @@ "gbin": "g-en-default.bin", "tokens": "tokens-en.txt", "lexicon": "lexicon-en.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", "spacker_speed": 1.2, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-us.json b/projects/llm_framework/main_melotts/mode_melotts-en-us.json index 6a375c93..d6320873 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-en-us.json +++ b/projects/llm_framework/main_melotts/mode_melotts-en-us.json @@ -1,9 +1,9 @@ { "mode": "melotts-en-us", "type": "tts", - "homepage":"https://huggingface.co/myshell-ai/MeloTTS-English", - "compile_flage":"pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder-en --output_name decoder-en.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", - "pulsar_version":"3.4-3dfd5692", + "homepage": "https://huggingface.co/myshell-ai/MeloTTS-English", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder-en --output_name decoder-en.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "3.4-3dfd5692", "capabilities": [ "tts", "English" @@ -21,6 +21,8 @@ "gbin": "g-en.bin", "tokens": "tokens.txt", "lexicon": "lexicon.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", "spacker_speed": 1.0, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json b/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json index d2df3e12..0b93f91e 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json +++ b/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json @@ -21,6 +21,8 @@ "gbin": "g-jp.bin", "tokens": "tokens-jp.txt", "lexicon": "lexicon-jp.txt", + "tagger": "ja_tn_tagger.fst", + "verbalizer": "ja_tn_verbalizer.fst", "spacker_speed": 1.1, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json b/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json index b5edfe02..17867b92 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json +++ b/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json @@ -21,6 +21,8 @@ "gbin": "g-zh_mix_en.bin", "tokens": "tokens.txt", "lexicon": "lexicon.txt", + "tagger": "zh_tn_tagger.fst", + "verbalizer": "zh_tn_verbalizer.fst", "spacker_speed": 1.1, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index 4c25df80..0875be97 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -9,7 +9,6 @@ #include "Lexicon.hpp" #include #include "AudioFile.h" -#include "Lexicon.hpp" #include #include @@ -44,6 +43,8 @@ typedef struct { std::string tokens; std::string gbin; std::string sentence; + std::string tagger; + std::string verbalizer; float spacker_speed = 1.0; int mode_rate = 44100; int audio_rate = 16000; @@ -169,17 +170,22 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], length_scale); CONFIG_AUTO_SET(file_body["mode_param"], noise_scale_w); CONFIG_AUTO_SET(file_body["mode_param"], sdp_ratio); - mode_config_.tokens = base_model + mode_config_.tokens; - mode_config_.gbin = base_model + mode_config_.gbin; - mode_config_.encoder = base_model + mode_config_.encoder; - mode_config_.decoder = base_model + mode_config_.decoder; - mode_config_.lexicon = base_model + mode_config_.lexicon; + CONFIG_AUTO_SET(file_body["mode_param"], tagger); + CONFIG_AUTO_SET(file_body["mode_param"], verbalizer); + mode_config_.tokens = base_model + mode_config_.tokens; + mode_config_.gbin = base_model + mode_config_.gbin; + mode_config_.encoder = base_model + mode_config_.encoder; + mode_config_.decoder = base_model + mode_config_.decoder; + mode_config_.lexicon = base_model + mode_config_.lexicon; + mode_config_.tagger = base_model + mode_config_.tagger; + mode_config_.verbalizer = base_model + mode_config_.verbalizer; if (config_body.contains("awake_delay")) awake_delay_ = config_body["awake_delay"].get(); else if (file_body["mode_param"].contains("awake_delay")) awake_delay_ = file_body["mode_param"]["awake_delay"]; // Load lexicon - lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens); + lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens, mode_config_.tagger, + mode_config_.verbalizer); // Read g.bin g_matrix.resize(256, 0); FILE *fp = fopen(mode_config_.gbin.c_str(), "rb"); @@ -243,7 +249,6 @@ class llm_task { try { std::vector wav_pcm_data; if (msg_str.empty()) { - SLOGI("empty"); if (out_callback_) { std::string output = wav_pcm_data.empty() ? std::string() : std::string((char *)wav_pcm_data.data(), @@ -253,7 +258,6 @@ class llm_task { return false; } - // Convert text to phonemes and tones std::vector phones_bef, tones_bef; lexicon_->convert(msg_str, phones_bef, tones_bef); auto phones = intersperse(phones_bef, 0); @@ -261,7 +265,6 @@ class llm_task { int phone_len = phones.size(); std::vector langids(phone_len, 3); - // Run the encoder to generate hidden representations auto encoder_output = encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w, mode_config_.get_length_scale(), mode_config_.sdp_ratio); @@ -270,7 +273,6 @@ class llm_task { auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo(); auto zp_shape = zp_info.GetShape(); - // Calculate decoder parameters int zp_size = decoder_->GetInputSize(0) / sizeof(float); int dec_len = zp_size / zp_shape[1]; int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float); @@ -283,12 +285,10 @@ class llm_task { int dec_slice_num = static_cast(std::ceil(static_cast(zp_shape[2]) / static_cast(effective_frames))); - // SOLA parameters setup - const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length - const int sola_search_frame = pad_frames * samples_per_frame; // Search window length - const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length + const int sola_buffer_frame = pad_frames * samples_per_frame; + const int sola_search_frame = pad_frames * samples_per_frame; + const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; - // Create fade-in/fade-out windows for smooth transitions std::vector fade_in_window(sola_buffer_frame); std::vector fade_out_window(sola_buffer_frame); @@ -297,46 +297,35 @@ class llm_task { fade_out_window[i] = 1.0f - fade_in_window[i]; } - // Initialize SOLA buffer std::vector sola_buffer(sola_buffer_frame, 0.0f); bool first_frame = true; std::vector pcmlist; - // Main decoding loop - process each slice for (int i = 0; i < dec_slice_num; i++) { - // Calculate start position for current batch input int input_start = i * effective_frames; - // Consider forward padding, but ensure non-negative if (i > 0) { input_start -= pad_frames; } input_start = std::max(0, input_start); - // Actual input length int actual_len = std::min(dec_len, static_cast(zp_shape[2] - input_start)); - // Calculate effective output range (frame level) int output_start_frame, output_end_frame; if (i == 0) { - // First frame: skip padding at beginning output_start_frame = 0; output_end_frame = effective_frames - 1; } else if (i == dec_slice_num - 1) { - // Last frame: calculate from current segment start output_start_frame = i * effective_frames; - // Last frame extends to encoder's maximum output length - output_end_frame = static_cast(zp_shape[2]) - 1; + output_end_frame = static_cast(zp_shape[2]) - 1; } else { - // Middle frames: standard calculation output_start_frame = i * effective_frames; output_end_frame = (i + 1) * effective_frames - 1; } - // Prepare decoder input, initialize all to zero + std::vector zp(zp_size, 0); - // Copy data to decoder input for (int n = 0; n < zp_shape[1]; n++) { int copy_size = std::min(actual_len, static_cast(zp_shape[2] - input_start)); if (copy_size > 0) { @@ -345,54 +334,37 @@ class llm_task { } } - // Run decoder std::vector decoder_output(audio_slice_len); decoder_->SetInput(zp.data(), 0); decoder_->SetInput(g_matrix.data(), 1); if (0 != decoder_->Run()) { - SLOGI("Inference #%d: decoding failed", i + 1); throw std::string("decoder_ RunSync error"); } decoder_->GetOutput(decoder_output.data(), 0); - // === SOLA Processing Logic === if (first_frame) { - // Special handling for first frame - should not skip initial content - // First frame starts directly from decoder output without skipping - int audio_start = 0; // Start from beginning, don't skip pad_frames - - // Calculate data length for first frame - // First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end - // for next frame alignment - int audio_len = decoder_output.size() - sola_buffer_frame; + int audio_start = 0; + int audio_len = decoder_output.size() - sola_buffer_frame; + audio_len = std::max(0, audio_len); - // Boundary check - audio_len = std::max(0, audio_len); // Ensure non-negative - - // Add first frame data if (audio_len > 0) { pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start, decoder_output.begin() + audio_start + audio_len); } - // Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment int buffer_start = audio_len; - // Ensure sufficient data is available for copying if (buffer_start + sola_buffer_frame <= decoder_output.size()) { std::copy(decoder_output.begin() + buffer_start, decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin()); } else { - // Possible case: first frame data is shorter than sola_buffer_frame int available = static_cast(decoder_output.size() - buffer_start); if (available > 0) { std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin()); - // Fill with zeros std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f); } else { - // Completely insufficient data, fill all with zeros std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f); } } @@ -400,15 +372,12 @@ class llm_task { first_frame = false; } else { - // Non-first frame: SOLA alignment required int audio_start = pad_frames * samples_per_frame; - // 1. Prepare search window - beginning portion of current frame std::vector search_window(sola_buffer_frame + sola_search_frame); std::copy(decoder_output.begin() + audio_start, decoder_output.begin() + audio_start + search_window.size(), search_window.begin()); - // 2. Find best alignment point (calculate cross-correlation) int best_offset = 0; float best_correlation = -1.0; @@ -421,7 +390,6 @@ class llm_task { energy += search_window[j + offset] * search_window[j + offset]; } - // Normalize correlation (avoid division by zero) float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f; if (normalized_correlation > best_correlation) { @@ -430,30 +398,25 @@ class llm_task { } } - // 3. Apply alignment offset + int aligned_start = audio_start + best_offset; - // 4. Smooth transition processing (crossfade in alignment region) std::vector crossfade_region(sola_buffer_frame); for (int j = 0; j < sola_buffer_frame; j++) { - // Apply fade-in/fade-out window functions crossfade_region[j] = decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j]; } - // 5. Add crossfade region to output pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end()); int remaining_start = aligned_start + sola_buffer_frame; if (i == dec_slice_num - 1) { int total_expected_samples = audio_len * samples_per_frame / 512; - - int processed_samples = static_cast(pcmlist.size()); - - int remaining_needed = total_expected_samples - processed_samples; - remaining_needed = std::max(0, remaining_needed); + int processed_samples = static_cast(pcmlist.size()); + int remaining_needed = total_expected_samples - processed_samples; + remaining_needed = std::max(0, remaining_needed); int remaining_len = std::min(remaining_needed, static_cast(decoder_output.size() - remaining_start)); @@ -465,7 +428,6 @@ class llm_task { } else { int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame; - remaining_len = std::min(remaining_len, static_cast(decoder_output.size() - remaining_start)); @@ -495,7 +457,7 @@ class llm_task { pcmlist.resize(audio_len); } - // Post-processing: resample and convert to int16 + double src_ratio = static_cast(mode_config_.audio_rate) / static_cast(mode_config_.mode_rate); std::vector tmp_pcm((pcmlist.size() * src_ratio + 1)); @@ -503,12 +465,11 @@ class llm_task { resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio); - // Convert to 16-bit PCM + wav_pcm_data.reserve(len); std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data), [](const auto val) { return static_cast(val * INT16_MAX); }); - // Call the output callback function with the result if (out_callback_) { out_callback_( std::string(reinterpret_cast(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)), @@ -516,10 +477,8 @@ class llm_task { } } catch (const std::exception &e) { - SLOGI("TTS processing exception: %s", e.what()); return true; } catch (...) { - SLOGI("TTS processing encountered an unknown exception"); return true; } return false; @@ -932,4 +891,4 @@ int main(int argc, char *argv[]) } llm.llm_firework_exit(); return 0; -} \ No newline at end of file +} diff --git a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp index d1bcbe90..134f64c4 100644 --- a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp +++ b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp @@ -8,6 +8,8 @@ #include #include #include "../../../../../SDK/components/utilities/include/sample_log.h" +#include "processor/wetext_processor.h" + // Debug logging switch - set to true to enable debug logs static bool DEBUG_LOGGING = false; // Macro for debug logging @@ -36,16 +38,23 @@ class Lexicon { std::pair, std::vector> unknown_token; std::unordered_map reverse_tokens; + wetext::Processor* m_processor; + public: // Setter for debug logging static void setDebugLogging(bool enable) { DEBUG_LOGGING = enable; } - Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0) + Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename, const std::string& tagger_filename, + const std::string& verbalizer_filename) + : max_phrase_length(0) { - DEBUG_LOG("Dictionary loading: %s Pronunciation table loading: %s", tokens_filename.c_str(), - lexicon_filename.c_str()); + DEBUG_LOG("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", + tokens_filename.c_str(), lexicon_filename.c_str(), tagger_filename.c_str(), + verbalizer_filename.c_str()); + + m_processor = new wetext::Processor(tagger_filename, verbalizer_filename); std::unordered_map tokens; std::ifstream ifs(tokens_filename); @@ -198,6 +207,12 @@ class Lexicon { void convert(const std::string& text, std::vector& phones, std::vector& tones) { DEBUG_LOG("\nStarting text processing: \"%s\"", text.c_str()); + + std::string taggedText = m_processor->Tag(text); + DEBUG_LOG("\taggedText processing: \"%s\"", taggedText.c_str()); + std::string normalizedText = m_processor->Verbalize(taggedText); + DEBUG_LOG("\normalizedText processing: \"%s\"", normalizedText.c_str()); + DEBUG_LOG("=======Matching Results======="); DEBUG_LOG("Unit\t|\tPhonemes\t|\tTones"); DEBUG_LOG("-----------------------------"); @@ -205,7 +220,7 @@ class Lexicon { tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); DEBUG_LOG("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), tonesToString(unknown_token.second).c_str()); - auto chars = splitEachChar(text); + auto chars = splitEachChar(normalizedText); int i = 0; while (i < chars.size()) { if (is_english(chars[i])) { diff --git a/projects/llm_framework/main_melotts/src/runner/base64.cpp b/projects/llm_framework/main_melotts/src/runner/base64.cpp index 5e0fd6ad..e8e1add3 100644 --- a/projects/llm_framework/main_melotts/src/runner/base64.cpp +++ b/projects/llm_framework/main_melotts/src/runner/base64.cpp @@ -1,17 +1,13 @@ #include "base64.h" static uint8 alphabet_map[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -static uint8 reverse_map[] = -{ -255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63, - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255, - 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255, - 255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255 -}; +static uint8 reverse_map[] = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, + 255, 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255, 255, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255}; // //GB2312到UTF-8的转换 // char* G2U(const char* gb2312) @@ -42,26 +38,27 @@ static uint8 reverse_map[] = // return str; // } -// uint32 base64_encode(char* input, uint8* encode) +// base64_uint32 base64_encode(char* input, uint8* encode) // { // //1、包含中文的字符串 字符编码(windows默认是gbk)转换成unicode - + // //2、字符编码方式是utf-8的二进制 // // uint8* text = (uint8*)G2U(input); -// uint32 text_len = (uint32)strlen((char*)input); +// base64_uint32 text_len = (base64_uint32)strlen((char*)input); -// uint32 i, j; +// base64_uint32 i, j; // for (i = 0, j = 0; i + 3 <= text_len; i += 3) // { -// encode[j++] = alphabet_map[text[i] >> 2]; //取出第一个字符的前6位并找出对应的结果字符 -// encode[j++] = alphabet_map[((text[i] << 4) & 0x30) | (text[i + 1] >> 4)]; //将第一个字符的后2位与第二个字符的前4位进行组合并找到对应的结果字符 -// encode[j++] = alphabet_map[((text[i + 1] << 2) & 0x3c) | (text[i + 2] >> 6)]; //将第二个字符的后4位与第三个字符的前2位组合并找出对应的结果字符 -// encode[j++] = alphabet_map[text[i + 2] & 0x3f]; //取出第三个字符的后6位并找出结果字符 +// encode[j++] = alphabet_map[text[i] >> 2]; //取出第一个字符的前6位并找出对应的结果字符 encode[j++] = +// alphabet_map[((text[i] << 4) & 0x30) | (text[i + 1] >> 4)]; +// //将第一个字符的后2位与第二个字符的前4位进行组合并找到对应的结果字符 encode[j++] = alphabet_map[((text[i + 1] << 2) & +// 0x3c) | (text[i + 2] >> 6)]; //将第二个字符的后4位与第三个字符的前2位组合并找出对应的结果字符 encode[j++] = +// alphabet_map[text[i + 2] & 0x3f]; //取出第三个字符的后6位并找出结果字符 // } // if (i < text_len) // { -// uint32 tail = text_len - i; +// base64_uint32 tail = text_len - i; // if (tail == 1) // { // encode[j++] = alphabet_map[text[i] >> 2]; @@ -81,40 +78,41 @@ static uint8 reverse_map[] = // return j; // } -int base64_decode(const uint8* code, uint32 code_len, char* str) +int base64_decode(const uint8* code, base64_uint32 code_len, char* str) { - uint8 plain[1024]; - assert((code_len & 0x03) == 0); //如果它的条件返回错误,则终止程序执行。4的倍数。 + uint8 plain[1024]; + assert((code_len & 0x03) == 0); // 如果它的条件返回错误,则终止程序执行。4的倍数。 - uint32 i, j = 0; - uint8 quad[4]; - for (i = 0; i < code_len; i += 4) - { - for (uint32 k = 0; k < 4; k++) - { - quad[k] = reverse_map[code[i + k]];//分组,每组四个分别依次转换为base64表内的十进制数 - } + base64_uint32 i, j = 0; + uint8 quad[4]; + for (i = 0; i < code_len; i += 4) { + for (base64_uint32 k = 0; k < 4; k++) { + quad[k] = reverse_map[code[i + k]]; // 分组,每组四个分别依次转换为base64表内的十进制数 + } - assert(quad[0] < 64 && quad[1] < 64); + assert(quad[0] < 64 && quad[1] < 64); - plain[j++] = (quad[0] << 2) | (quad[1] >> 4); //取出第一个字符对应base64表的十进制数的前6位与第二个字符对应base64表的十进制数的前2位进行组合 + plain[j++] = + (quad[0] << 2) | + (quad[1] >> + 4); // 取出第一个字符对应base64表的十进制数的前6位与第二个字符对应base64表的十进制数的前2位进行组合 - if (quad[2] >= 64) - break; - else if (quad[3] >= 64) - { - plain[j++] = (quad[1] << 4) | (quad[2] >> 2); //取出第二个字符对应base64表的十进制数的后4位与第三个字符对应base64表的十进制数的前4位进行组合 - break; - } - else - { - plain[j++] = (quad[1] << 4) | (quad[2] >> 2); - plain[j++] = (quad[2] << 6) | quad[3];//取出第三个字符对应base64表的十进制数的后2位与第4个字符进行组合 - } - } - plain[j] = 0; - // char str[1024] = ""; - strcpy(str, (char*)plain); - // strcpy_s(str, sizeof(plain), U2G(str)); - return j; + if (quad[2] >= 64) + break; + else if (quad[3] >= 64) { + plain[j++] = + (quad[1] << 4) | + (quad[2] >> + 2); // 取出第二个字符对应base64表的十进制数的后4位与第三个字符对应base64表的十进制数的前4位进行组合 + break; + } else { + plain[j++] = (quad[1] << 4) | (quad[2] >> 2); + plain[j++] = (quad[2] << 6) | quad[3]; // 取出第三个字符对应base64表的十进制数的后2位与第4个字符进行组合 + } + } + plain[j] = 0; + // char str[1024] = ""; + strcpy(str, (char*)plain); + // strcpy_s(str, sizeof(plain), U2G(str)); + return j; } \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/src/runner/base64.h b/projects/llm_framework/main_melotts/src/runner/base64.h index 8e3dcb6c..f7a01c7e 100644 --- a/projects/llm_framework/main_melotts/src/runner/base64.h +++ b/projects/llm_framework/main_melotts/src/runner/base64.h @@ -2,10 +2,10 @@ #include #include -#include +#include #include #include -typedef unsigned char uint8; -typedef unsigned long uint32; -// uint32 base64_encode(char* input, uint8* encode); -int base64_decode(const uint8* code, uint32 code_len, char* str); \ No newline at end of file +typedef unsigned char uint8; +typedef unsigned long base64_uint32; +// base64_uint32 base64_encode(char* input, uint8* encode); +int base64_decode(const uint8* code, base64_uint32 code_len, char* str); \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt b/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt new file mode 100644 index 00000000..a2e8d97c --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(wetext_processor STATIC + wetext_processor.cc + wetext_token_parser.cc +) +if(ANDROID) + target_link_libraries(wetext_processor PUBLIC fst wetext_utils) +else() + if(MSVC) + target_link_libraries(wetext_processor PUBLIC fst wetext_utils) + else() + target_link_libraries(wetext_processor PUBLIC dl fst wetext_utils) + endif() +endif() diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc new file mode 100644 index 00000000..eec45a24 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "processor/wetext_processor.h" + +using fst::StringTokenType; + +namespace wetext { +Processor::Processor(const std::string& tagger_path, const std::string& verbalizer_path) +{ + tagger_.reset(StdVectorFst::Read(tagger_path)); + verbalizer_.reset(StdVectorFst::Read(verbalizer_path)); + compiler_ = std::make_shared>(StringTokenType::BYTE); + printer_ = std::make_shared>(StringTokenType::BYTE); + + if (tagger_path.find("zh_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_TN; + } else if (tagger_path.find("zh_itn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_ITN; + } else if (tagger_path.find("en_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kEN_TN; + } else if (tagger_path.find("ja_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_TN; // 如果是日语的文件开始,也使用中文的规则进行转换 + } else { + LOG(FATAL) << "Invalid fst prefix, prefix should contain" << " either \"_tn_\" or \"_itn_\"."; + } +} + +std::string Processor::ShortestPath(const StdVectorFst& lattice) +{ + StdVectorFst shortest_path; + fst::ShortestPath(lattice, &shortest_path, 1, true); + + std::string output; + printer_->operator()(shortest_path, &output); + return output; +} + +std::string Processor::Compose(const std::string& input, const StdVectorFst* fst) +{ + StdVectorFst input_fst; + compiler_->operator()(input, &input_fst); + + StdVectorFst lattice; + fst::Compose(input_fst, *fst, &lattice); + return ShortestPath(lattice); +} + +std::string Processor::Tag(const std::string& input) +{ + if (input.empty()) { + return ""; + } + return Compose(input, tagger_.get()); +} + +std::string Processor::Verbalize(const std::string& input) +{ + if (input.empty()) { + return ""; + } + TokenParser parser(parse_type_); + std::string output = parser.Reorder(input); + + output = Compose(output, verbalizer_.get()); + output.erase(std::remove(output.begin(), output.end(), '\0'), output.end()); + return output; +} + +std::string Processor::Normalize(const std::string& input) +{ + return Verbalize(Tag(input)); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h new file mode 100644 index 00000000..e11d307e --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PROCESSOR_WETEXT_PROCESSOR_H_ +#define PROCESSOR_WETEXT_PROCESSOR_H_ + +#include +#include + +#include "fst/fstlib.h" + +#include "processor/wetext_token_parser.h" + +using fst::StdArc; +using fst::StdVectorFst; +using fst::StringCompiler; +using fst::StringPrinter; + +namespace wetext { +class Processor { + public: + Processor(const std::string& tagger_path, const std::string& verbalizer_path); + std::string Tag(const std::string& input); + std::string Verbalize(const std::string& input); + std::string Normalize(const std::string& input); + + private: + std::string ShortestPath(const StdVectorFst& lattice); + std::string Compose(const std::string& input, const StdVectorFst* fst); + + ParseType parse_type_; + std::shared_ptr tagger_ = nullptr; + std::shared_ptr verbalizer_ = nullptr; + std::shared_ptr> compiler_ = nullptr; + std::shared_ptr> printer_ = nullptr; +}; + +} // namespace wetext + +#endif // PROCESSOR_WETEXT_PROCESSOR_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc new file mode 100644 index 00000000..a600eead --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "processor/wetext_token_parser.h" + +#include "utils/wetext_log.h" +#include "utils/wetext_string.h" + +namespace wetext { +const char EOS[] = ""; +const std::set UTF8_WHITESPACE = {" ", "\t", "\n", "\r", + "\x0b\x0c"}; +const std::set ASCII_LETTERS = { + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", + "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", + "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", + "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "_"}; +const std::unordered_map> ZH_TN_ORDERS = { + {"date", {"year", "month", "day"}}, + {"fraction", {"denominator", "numerator"}}, + {"measure", {"denominator", "numerator", "value"}}, + {"money", {"value", "currency"}}, + {"time", {"noon", "hour", "minute", "second"}}}; +const std::unordered_map> EN_TN_ORDERS = { + {"date", {"preserve_order", "text", "day", "month", "year"}}, + {"money", {"integer_part", "fractional_part", "quantity", "currency_maj"}}}; +const std::unordered_map> ZH_ITN_ORDERS = + {{"date", {"year", "month", "day"}}, + {"fraction", {"sign", "numerator", "denominator"}}, + {"measure", {"numerator", "denominator", "value"}}, + {"money", {"currency", "value", "decimal"}}, + {"time", {"hour", "minute", "second", "noon"}}}; + +TokenParser::TokenParser(ParseType type) { + if (type == ParseType::kZH_TN) { + orders_ = ZH_TN_ORDERS; + } else if (type == ParseType::kZH_ITN) { + orders_ = ZH_ITN_ORDERS; + } else if (type == ParseType::kEN_TN) { + orders_ = EN_TN_ORDERS; + } else { + LOG(FATAL) << "Invalid order"; + } +} + +void TokenParser::Load(const std::string& input) { + wetext::SplitUTF8StringToChars(input, &text_); + CHECK_GT(text_.size(), 0); + index_ = 0; + ch_ = text_[0]; +} + +bool TokenParser::Read() { + if (index_ < text_.size() - 1) { + index_ += 1; + ch_ = text_[index_]; + return true; + } + ch_ = EOS; + return false; +} + +bool TokenParser::ParseWs() { + bool not_eos = ch_ != EOS; + while (not_eos && ch_ == " ") { + not_eos = Read(); + } + return not_eos; +} + +bool TokenParser::ParseChar(const std::string& exp) { + if (ch_ == exp) { + Read(); + return true; + } + return false; +} + +bool TokenParser::ParseChars(const std::string& exp) { + bool ok = false; + std::vector chars; + wetext::SplitUTF8StringToChars(exp, &chars); + for (const auto& x : chars) { + ok |= ParseChar(x); + } + return ok; +} + +std::string TokenParser::ParseKey() { + CHECK_NE(ch_, EOS); + CHECK_EQ(UTF8_WHITESPACE.count(ch_), 0); + + std::string key = ""; + while (ASCII_LETTERS.count(ch_) > 0) { + key += ch_; + Read(); + } + return key; +} + +std::string TokenParser::ParseValue() { + CHECK_NE(ch_, EOS); + bool escape = false; + + std::string value = ""; + while (ch_ != "\"") { + value += ch_; + escape = ch_ == "\\"; + Read(); + if (escape) { + escape = false; + value += ch_; + Read(); + } + } + return value; +} + +void TokenParser::Parse(const std::string& input) { + Load(input); + while (ParseWs()) { + std::string name = ParseKey(); + ParseChars(" { "); + + Token token(name); + while (ParseWs()) { + if (ch_ == "}") { + ParseChar("}"); + break; + } + std::string key = ParseKey(); + ParseChars(": \""); + std::string value = ParseValue(); + ParseChar("\""); + token.Append(key, value); + } + tokens_.emplace_back(token); + } +} + +std::string TokenParser::Reorder(const std::string& input) { + Parse(input); + std::string output = ""; + for (auto& token : tokens_) { + output += token.String(orders_) + " "; + } + return Trim(output); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h new file mode 100644 index 00000000..34aba979 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PROCESSOR_WETEXT_TOKEN_PARSER_H_ +#define PROCESSOR_WETEXT_TOKEN_PARSER_H_ + +#include +#include +#include +#include + +namespace wetext { + +extern const char EOS[]; +extern const std::set UTF8_WHITESPACE; +extern const std::set ASCII_LETTERS; +extern const std::unordered_map> + ZH_TN_ORDERS; +extern const std::unordered_map> + ZH_ITN_ORDERS; +extern const std::unordered_map> + EN_TN_ORDERS; + +struct Token { + std::string name; + std::vector order; + std::unordered_map members; + + explicit Token(const std::string& name) : name(name) {} + + void Append(const std::string& key, const std::string& value) { + order.emplace_back(key); + members[key] = value; + } + + std::string String( + const std::unordered_map>& orders) { + std::string output = name + " {"; + if (orders.count(name) > 0) { + order = orders.at(name); + } + + for (const auto& key : order) { + if (members.count(key) == 0) { + continue; + } + output += " " + key + ": \"" + members[key] + "\""; + } + return output + " }"; + } +}; + +enum ParseType { + kZH_TN = 0x00, // Chinese Text Normalization + kZH_ITN = 0x01, // Chinese Inverse Text Normalization + kEN_TN = 0x02 // English Text Normalization +}; + +class TokenParser { + public: + explicit TokenParser(ParseType type); + std::string Reorder(const std::string& input); + + private: + void Load(const std::string& input); + bool Read(); + bool ParseWs(); + bool ParseChar(const std::string& exp); + bool ParseChars(const std::string& exp); + std::string ParseKey(); + std::string ParseValue(); + void Parse(const std::string& input); + + int index_; + std::string ch_; + std::vector text_; + std::vector tokens_; + std::unordered_map> orders_; +}; + +} // namespace wetext + +#endif // PROCESSOR_WETEXT_TOKEN_PARSER_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt b/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt new file mode 100644 index 00000000..30071f4c --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(wetext_utils STATIC wetext_string.cc) + +target_link_libraries(wetext_utils PUBLIC glog) diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h new file mode 100644 index 00000000..c1d30df3 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_FLAGS_H_ +#define UTILS_WETEXT_FLAGS_H_ + +// Because openfst is a dynamic library compiled with gflags/glog, we must use +// the gflags/glog from openfst to avoid them linked both statically and +// dynamically into the executable. +#include "fst/flags.h" + +#endif // UTILS_WETEXT_FLAGS_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h new file mode 100644 index 00000000..b47a6a48 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_LOG_H_ +#define UTILS_WETEXT_LOG_H_ + +// Because openfst is a dynamic library compiled with gflags/glog, we must use +// the gflags/glog from openfst to avoid them linked both statically and +// dynamically into the executable. +#include "fst/log.h" + +#endif // UTILS_WETEXT_LOG_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc new file mode 100644 index 00000000..4df9ec91 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "utils/wetext_string.h" + +#include "utils/wetext_log.h" + +namespace wetext { +const char* WHITESPACE = " \n\r\t\f\v"; + +int UTF8CharLength(char ch) { + int num_bytes = 1; + CHECK_LE((ch & 0xF8), 0xF0); + if ((ch & 0x80) == 0x00) { + // The first 128 characters (US-ASCII) in UTF-8 format only need one byte. + num_bytes = 1; + } else if ((ch & 0xE0) == 0xC0) { + // The next 1,920 characters need two bytes to encode, + // which covers the remainder of almost all Latin-script alphabets. + num_bytes = 2; + } else if ((ch & 0xF0) == 0xE0) { + // Three bytes are needed for characters in the rest of + // the Basic Multilingual Plane, which contains virtually all characters + // in common use, including most Chinese, Japanese and Korean characters. + num_bytes = 3; + } else if ((ch & 0xF8) == 0xF0) { + // Four bytes are needed for characters in the other planes of Unicode, + // which include less common CJK characters, various historic scripts, + // mathematical symbols, and emoji (pictographic symbols). + num_bytes = 4; + } + return num_bytes; +} + +int UTF8StringLength(const std::string& str) { + int len = 0; + int num_bytes = 1; + for (size_t i = 0; i < str.length(); i += num_bytes) { + num_bytes = UTF8CharLength(str[i]); + ++len; + } + return len; +} + +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars) { + chars->clear(); + int num_bytes = 1; + for (size_t i = 0; i < str.length(); i += num_bytes) { + num_bytes = UTF8CharLength(str[i]); + chars->push_back(str.substr(i, num_bytes)); + } +} + +std::string Ltrim(const std::string& str) { + size_t start = str.find_first_not_of(WHITESPACE); + return (start == std::string::npos) ? "" : str.substr(start); +} + +std::string Rtrim(const std::string& str) { + size_t end = str.find_last_not_of(WHITESPACE); + return end == std::string::npos ? "" : str.substr(0, end + 1); +} + +std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); } + +void Split(const std::string& str, const std::string& delim, + std::vector* output) { + std::string s = str; + size_t pos = 0; + while ((pos = s.find(delim)) != std::string::npos) { + output->emplace_back(s.substr(0, pos)); + s.erase(0, pos + delim.length()); + } + output->emplace_back(s); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h new file mode 100644 index 00000000..ae890d60 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_STRING_H_ +#define UTILS_WETEXT_STRING_H_ + +#include +#include + +namespace wetext { +extern const char* WHITESPACE; + +int UTF8CharLength(char ch); + +int UTF8StringLength(const std::string& str); + +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars); + +std::string Ltrim(const std::string& str); + +std::string Rtrim(const std::string& str); + +std::string Trim(const std::string& str); + +void Split(const std::string& str, const std::string& delim, + std::vector* output); + +} // namespace wetext + +#endif // UTILS_WETEXT_STRING_H_