diff --git a/AGENTS.md b/AGENTS.md index c3b0d5b47c..80038fd49a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -122,6 +122,8 @@ Tests require PostgreSQL running locally. Start it with `docker compose up -d po Keep domain test helpers under `openmeter/.../testutils` independent from `app/common`. Build test dependencies from the underlying package constructors (repos, adapters, services, `lockr`) instead of importing the application wiring layer, or unrelated wiring additions can create test-only import cycles. +For usage-based billing lifecycle tests, prefer driving behavior through `charges.Service.Create`, `AdvanceCharges`, and `ApplyPatches` rather than calling lower-level charge adapters directly. To model late-arriving or newly visible usage, use `MockStreamingConnector` events with explicit `StoredAt` values (or `SetSimpleEvents`) so the test exercises the real stored-at cutoff logic in finalization. + For OpenMeter Go tests that touch the database, explicitly set `POSTGRES_HOST=127.0.0.1`. Without it, many suites will skip during setup even if PostgreSQL is running and the repo environment is otherwise loaded correctly. Use the repo's Nix CI dev shell when `go`, `gofmt`, or other toolchain binaries are missing from the ambient shell. The CI and local-compatible invocation pattern is: @@ -180,6 +182,8 @@ All builds use `GO_BUILD_FLAGS=-tags=dynamic`. See the `/service` skill for service/adapter patterns, constructors, input types, errors, transactions, hooks, logging, multi-tenancy, and DI wiring. See the `/api` skill for HTTP handler patterns and ValidationIssue. See the `/ent` skill for Ent ORM patterns and Postgres type gotchas. See the `/ledger` skill for ledger package architecture, wiring, and testing. See the `/subscription` skill for subscription domain model, sync algorithm, patch system, workflow layer, and addon sub-system. See the `/notification` skill for notification event pipeline, Kafka consumers, Svix webhook delivery, reconciliation loop, and payload versioning. +In `openmeter/billing/charges/.../adapter`, keep Ent access transaction-aware even in shared helper functions. If a helper accepts a raw `*entdb.Client`, still wrap its body with `entutils.TransactingRepo(...)` / `TransactingRepoWithNoValue(...)` so it rebinds to the transaction already carried in `ctx` instead of depending on the caller to pass a tx-specific client. + ## Key Dependencies | Category | Libraries | diff --git a/app/common/charges.go b/app/common/charges.go index 341ffe47c0..164c8dd548 100644 --- a/app/common/charges.go +++ b/app/common/charges.go @@ -15,6 +15,9 @@ import ( flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" flatfeelineengine "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/lineengine" flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" chargesservice "github.com/openmeterio/openmeter/openmeter/billing/charges/service" @@ -26,6 +29,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger" ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" ledgerchargeadapter "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/framework/lockr" @@ -46,12 +51,30 @@ func NewChargesMetaAdapter( return metaAdapter, nil } +func NewChargesCollectorService( + ledgerService ledger.Ledger, + accountResolver ledger.AccountResolver, + accountService ledgeraccount.Service, +) ledgercollector.Service { + return ledgercollector.NewService(ledgercollector.Config{ + Ledger: ledgerService, + Dependencies: transactions.ResolverDependencies{ + AccountService: accountResolver, + SubAccountService: accountService, + }, + }) +} + func NewChargesFlatFeeHandler( ledgerService ledger.Ledger, accountResolver ledger.AccountResolver, accountService ledgeraccount.Service, + collectorService ledgercollector.Service, ) flatfee.Handler { - return ledgerchargeadapter.NewFlatFeeHandler(ledgerService, accountResolver, accountService) + return ledgerchargeadapter.NewFlatFeeHandler(ledgerService, transactions.ResolverDependencies{ + AccountService: accountResolver, + SubAccountService: accountService, + }, collectorService) } func NewChargesCreditPurchaseHandler( @@ -62,8 +85,8 @@ func NewChargesCreditPurchaseHandler( return ledgerchargeadapter.NewCreditPurchaseHandler(ledgerService, accountResolver, accountService) } -func NewChargesUsageBasedHandler() usagebased.Handler { - return usagebased.UnimplementedHandler{} +func NewChargesUsageBasedHandler(collectorService ledgercollector.Service) usagebased.Handler { + return ledgerchargeadapter.NewUsageBasedHandler(collectorService) } func NewChargesFlatFeeAdapter( @@ -83,15 +106,43 @@ func NewChargesFlatFeeAdapter( return flatFeeAdapter, nil } +func NewChargesLineageAdapter( + db *entdb.Client, +) (lineage.Adapter, error) { + lineageAdapter, err := lineageadapter.New(lineageadapter.Config{ + Client: db, + }) + if err != nil { + return nil, fmt.Errorf("failed to create charges lineage adapter: %w", err) + } + + return lineageAdapter, nil +} + +func NewChargesLineageService( + lineageAdapter lineage.Adapter, +) (lineage.Service, error) { + lineageService, err := lineageservice.New(lineageservice.Config{ + Adapter: lineageAdapter, + }) + if err != nil { + return nil, fmt.Errorf("failed to create charges lineage service: %w", err) + } + + return lineageService, nil +} + func NewChargesFlatFeeService( flatFeeAdapter flatfee.Adapter, flatFeeHandler flatfee.Handler, + lineageService lineage.Service, metaAdapter meta.Adapter, locker *lockr.Locker, ) (flatfee.Service, error) { flatFeeSvc, err := flatfeeservice.New(flatfeeservice.Config{ Adapter: flatFeeAdapter, Handler: flatFeeHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, Locker: locker, }) @@ -122,6 +173,7 @@ func NewChargesUsageBasedAdapter( func NewChargesUsageBasedService( usageBasedAdapter usagebased.Adapter, usageBasedHandler usagebased.Handler, + lineageService lineage.Service, locker *lockr.Locker, metaAdapter meta.Adapter, billingService billing.Service, @@ -132,6 +184,7 @@ func NewChargesUsageBasedService( usageBasedSvc, err := usagebasedservice.New(usagebasedservice.Config{ Adapter: usageBasedAdapter, Handler: usageBasedHandler, + Lineage: lineageService, Locker: locker, MetaAdapter: metaAdapter, CustomerOverrideService: billingService, @@ -166,11 +219,13 @@ func NewChargesCreditPurchaseAdapter( func NewChargesCreditPurchaseService( creditPurchaseAdapter creditpurchase.Adapter, creditPurchaseHandler creditpurchase.Handler, + lineageService lineage.Service, metaAdapter meta.Adapter, ) (creditpurchase.Service, error) { creditPurchaseSvc, err := creditpurchaseservice.New(creditpurchaseservice.Config{ Adapter: creditPurchaseAdapter, Handler: creditPurchaseHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, }) if err != nil { @@ -242,16 +297,27 @@ func newChargesRegistry( return nil, err } - flatFeeHandler := NewChargesFlatFeeHandler(ledgerService, accountResolver, accountService) + lineageAdapter, err := NewChargesLineageAdapter(db) + if err != nil { + return nil, err + } + + lineageService, err := NewChargesLineageService(lineageAdapter) + if err != nil { + return nil, err + } + + collectorService := NewChargesCollectorService(ledgerService, accountResolver, accountService) + flatFeeHandler := NewChargesFlatFeeHandler(ledgerService, accountResolver, accountService, collectorService) creditPurchaseHandler := NewChargesCreditPurchaseHandler(ledgerService, accountResolver, accountService) - usageBasedHandler := NewChargesUsageBasedHandler() + usageBasedHandler := NewChargesUsageBasedHandler(collectorService) flatFeeAdapter, err := NewChargesFlatFeeAdapter(db, logger, metaAdapter) if err != nil { return nil, err } - flatFeeSvc, err := NewChargesFlatFeeService(flatFeeAdapter, flatFeeHandler, metaAdapter, locker) + flatFeeSvc, err := NewChargesFlatFeeService(flatFeeAdapter, flatFeeHandler, lineageService, metaAdapter, locker) if err != nil { return nil, err } @@ -276,6 +342,7 @@ func newChargesRegistry( usageBasedSvc, err := NewChargesUsageBasedService( usageBasedAdapter, usageBasedHandler, + lineageService, locker, metaAdapter, billingService, @@ -292,7 +359,7 @@ func newChargesRegistry( return nil, err } - creditPurchaseSvc, err := NewChargesCreditPurchaseService(creditPurchaseAdapter, creditPurchaseHandler, metaAdapter) + creditPurchaseSvc, err := NewChargesCreditPurchaseService(creditPurchaseAdapter, creditPurchaseHandler, lineageService, metaAdapter) if err != nil { return nil, err } diff --git a/app/common/customerbalance.go b/app/common/customerbalance.go index 94ebe561ae..8f67bbd0da 100644 --- a/app/common/customerbalance.go +++ b/app/common/customerbalance.go @@ -12,6 +12,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" @@ -22,7 +24,9 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger" ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" ledgerchargeadapter "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/framework/lockr" @@ -59,6 +63,20 @@ func NewCustomerBalanceService( return nil, err } + lineageAdapter, err := lineageadapter.New(lineageadapter.Config{ + Client: db, + }) + if err != nil { + return nil, err + } + + lineageService, err := lineageservice.New(lineageservice.Config{ + Adapter: lineageAdapter, + }) + if err != nil { + return nil, err + } + searchAdapter, err := chargeadapter.New(chargeadapter.Config{ Client: db, Logger: logger, @@ -76,9 +94,18 @@ func NewCustomerBalanceService( return nil, err } + collectorService := ledgercollector.NewService(ledgercollector.Config{ + Ledger: historicalLedger, + Dependencies: transactions.ResolverDependencies{ + AccountService: accountResolver, + SubAccountService: accountService, + }, + }) + flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ Adapter: flatFeeAdapter, - Handler: ledgerchargeadapter.NewFlatFeeHandler(historicalLedger, accountResolver, accountService), + Handler: ledgerchargeadapter.NewFlatFeeHandler(historicalLedger, transactions.ResolverDependencies{AccountService: accountResolver, SubAccountService: accountService}, collectorService), + Lineage: lineageService, MetaAdapter: metaAdapter, Locker: locker, }) @@ -97,7 +124,8 @@ func NewCustomerBalanceService( usageService, err := usagebasedservice.New(usagebasedservice.Config{ Adapter: usageAdapter, - Handler: usagebased.UnimplementedHandler{}, + Handler: ledgerchargeadapter.NewUsageBasedHandler(collectorService), + Lineage: lineageService, Locker: locker, MetaAdapter: metaAdapter, CustomerOverrideService: billingRegistry.Billing, diff --git a/openmeter/billing/charges/creditpurchase/service/external.go b/openmeter/billing/charges/creditpurchase/service/external.go index 483ea9f83d..edf8bd3b30 100644 --- a/openmeter/billing/charges/creditpurchase/service/external.go +++ b/openmeter/billing/charges/creditpurchase/service/external.go @@ -4,6 +4,7 @@ import ( "context" "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/payment" @@ -19,20 +20,38 @@ func (s *service) onExternalCreditPurchase(ctx context.Context, charge creditpur targetStatus := externalCreditPurchaseSettlement.InitialStatus - // Let's handle the initial state - ledgerTransactionGroupReference, err := s.handler.OnCreditPurchaseInitiated(ctx, charge) - if err != nil { - return creditpurchase.Charge{}, err - } + charge, err = transaction.Run(ctx, s.adapter, func(ctx context.Context) (creditpurchase.Charge, error) { + ledgerTransactionGroupReference, err := s.handler.OnCreditPurchaseInitiated(ctx, charge) + if err != nil { + return creditpurchase.Charge{}, err + } - charge.State.CreditGrantRealization = &ledgertransaction.TimedGroupReference{ - GroupReference: ledgerTransactionGroupReference, - Time: clock.Now(), - } + charge.State.CreditGrantRealization = &ledgertransaction.TimedGroupReference{ + GroupReference: ledgerTransactionGroupReference, + Time: clock.Now(), + } + + if ledgerTransactionGroupReference.TransactionGroupID != "" { + if err := s.lineage.BackfillAdvanceLineageSegments(ctx, lineage.BackfillAdvanceLineageSegmentsInput{ + Namespace: charge.Namespace, + CustomerID: charge.Intent.CustomerID, + Currency: charge.Intent.Currency, + Amount: charge.Intent.CreditAmount, + BackingTransactionGroupID: ledgerTransactionGroupReference.TransactionGroupID, + }); err != nil { + return creditpurchase.Charge{}, err + } + } + + charge.Status = meta.ChargeStatusActive - charge.Status = meta.ChargeStatusActive + updatedCharge, err := s.adapter.UpdateCharge(ctx, charge) + if err != nil { + return creditpurchase.Charge{}, err + } - charge, err = s.adapter.UpdateCharge(ctx, charge) + return updatedCharge, nil + }) if err != nil { return creditpurchase.Charge{}, err } diff --git a/openmeter/billing/charges/creditpurchase/service/invoice.go b/openmeter/billing/charges/creditpurchase/service/invoice.go index a2dc9cbce3..980b40a2b7 100644 --- a/openmeter/billing/charges/creditpurchase/service/invoice.go +++ b/openmeter/billing/charges/creditpurchase/service/invoice.go @@ -6,26 +6,43 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/payment" "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/framework/transaction" ) func (s *service) PostInvoiceDraftCreated(ctx context.Context, charge creditpurchase.Charge, lineWithHeader billing.StandardLineWithInvoiceHeader) error { - ledgerTransactionGroupReference, err := s.handler.OnCreditPurchaseInitiated(ctx, charge) - if err != nil { + return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error { + ledgerTransactionGroupReference, err := s.handler.OnCreditPurchaseInitiated(ctx, charge) + if err != nil { + return err + } + + charge.State.CreditGrantRealization = &ledgertransaction.TimedGroupReference{ + GroupReference: ledgerTransactionGroupReference, + Time: clock.Now(), + } + + if ledgerTransactionGroupReference.TransactionGroupID != "" { + if err := s.lineage.BackfillAdvanceLineageSegments(ctx, lineage.BackfillAdvanceLineageSegmentsInput{ + Namespace: charge.Namespace, + CustomerID: charge.Intent.CustomerID, + Currency: charge.Intent.Currency, + Amount: charge.Intent.CreditAmount, + BackingTransactionGroupID: ledgerTransactionGroupReference.TransactionGroupID, + }); err != nil { + return err + } + } + + charge.Status = meta.ChargeStatusActive + + _, err = s.adapter.UpdateCharge(ctx, charge) return err - } - - charge.State.CreditGrantRealization = &ledgertransaction.TimedGroupReference{ - GroupReference: ledgerTransactionGroupReference, - Time: clock.Now(), - } - charge.Status = meta.ChargeStatusActive - - _, err = s.adapter.UpdateCharge(ctx, charge) - return err + }) } // PostInvoicePaymentAuthorized is called when the invoice is approved/issued. diff --git a/openmeter/billing/charges/creditpurchase/service/service.go b/openmeter/billing/charges/creditpurchase/service/service.go index 9ba5a8a865..fd023e3f1d 100644 --- a/openmeter/billing/charges/creditpurchase/service/service.go +++ b/openmeter/billing/charges/creditpurchase/service/service.go @@ -4,12 +4,14 @@ import ( "errors" "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" ) type Config struct { Adapter creditpurchase.Adapter Handler creditpurchase.Handler + Lineage lineage.Service MetaAdapter meta.Adapter } @@ -24,6 +26,10 @@ func (c Config) Validate() error { errs = append(errs, errors.New("credit purchase handler cannot be null")) } + if c.Lineage == nil { + errs = append(errs, errors.New("lineage service cannot be null")) + } + if c.MetaAdapter == nil { errs = append(errs, errors.New("meta adapter cannot be null")) } @@ -39,6 +45,7 @@ func New(config Config) (creditpurchase.Service, error) { return &service{ adapter: config.Adapter, handler: config.Handler, + lineage: config.Lineage, metaAdapter: config.MetaAdapter, }, nil } @@ -47,4 +54,5 @@ type service struct { adapter creditpurchase.Adapter metaAdapter meta.Adapter handler creditpurchase.Handler + lineage lineage.Service } diff --git a/openmeter/billing/charges/flatfee/adapter/charge.go b/openmeter/billing/charges/flatfee/adapter/charge.go index 774a4d6bc1..ba1f18130e 100644 --- a/openmeter/billing/charges/flatfee/adapter/charge.go +++ b/openmeter/billing/charges/flatfee/adapter/charge.go @@ -144,6 +144,7 @@ func (a *adapter) CreateCharges(ctx context.Context, in flatfee.CreateChargesInp } out = append(out, charge) } + return out, nil }) } @@ -172,9 +173,14 @@ func (a *adapter) GetByIDs(ctx context.Context, input flatfee.GetByIDsInput) ([] return nil, err } - return slicesx.MapWithErr(entitiesInOrder, func(entity *db.ChargeFlatFee) (flatfee.Charge, error) { + out, err := slicesx.MapWithErr(entitiesInOrder, func(entity *db.ChargeFlatFee) (flatfee.Charge, error) { return MapChargeFlatFeeFromDB(entity, input.Expands) }) + if err != nil { + return nil, err + } + + return out, nil }) } @@ -201,7 +207,12 @@ func (a *adapter) GetByID(ctx context.Context, input flatfee.GetByIDInput) (flat return flatfee.Charge{}, fmt.Errorf("querying flat fee charge [id=%s]: %w", input.ChargeID, err) } - return MapChargeFlatFeeFromDB(entity, input.Expands) + charge, err := MapChargeFlatFeeFromDB(entity, input.Expands) + if err != nil { + return flatfee.Charge{}, err + } + + return charge, nil }) } diff --git a/openmeter/billing/charges/flatfee/adapter/credits.go b/openmeter/billing/charges/flatfee/adapter/credits.go index b288f5a07a..4b61ef0372 100644 --- a/openmeter/billing/charges/flatfee/adapter/credits.go +++ b/openmeter/billing/charges/flatfee/adapter/credits.go @@ -32,8 +32,13 @@ func (a *adapter) CreateCreditAllocations(ctx context.Context, chargeID meta.Cha return creditrealization.Realizations{}, err } - return slicesx.MapWithErr(dbEntities, func(entity *db.ChargeFlatFeeCreditAllocations) (creditrealization.Realization, error) { + realizations, err := slicesx.MapWithErr(dbEntities, func(entity *db.ChargeFlatFeeCreditAllocations) (creditrealization.Realization, error) { return creditrealization.MapFromDB(entity), nil }) + if err != nil { + return creditrealization.Realizations{}, err + } + + return realizations, nil }) } diff --git a/openmeter/billing/charges/flatfee/handler.go b/openmeter/billing/charges/flatfee/handler.go index e05672ceaf..a96083409a 100644 --- a/openmeter/billing/charges/flatfee/handler.go +++ b/openmeter/billing/charges/flatfee/handler.go @@ -8,6 +8,7 @@ import ( "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/billing/models/totals" @@ -87,7 +88,8 @@ type CreditsOnlyUsageAccruedCorrectionInput struct { Charge Charge `json:"charge"` AllocateAt time.Time `json:"allocateAt"` - Corrections creditrealization.CorrectionRequest `json:"corrections"` + Corrections creditrealization.CorrectionRequest `json:"corrections"` + LineageSegmentsByRealization lineage.ActiveSegmentsByRealizationID `json:"-"` } func (i CreditsOnlyUsageAccruedCorrectionInput) Validate() error { diff --git a/openmeter/billing/charges/flatfee/service/creditsonly.go b/openmeter/billing/charges/flatfee/service/creditsonly.go index 65cc38b5f1..2c87df1c74 100644 --- a/openmeter/billing/charges/flatfee/service/creditsonly.go +++ b/openmeter/billing/charges/flatfee/service/creditsonly.go @@ -153,7 +153,7 @@ func (s *CreditsOnlyStateMachine) AllocateCredits(ctx context.Context) error { } if len(creditAllocations) > 0 { - realizations, err := s.Adapter.CreateCreditAllocations(ctx, s.Charge.GetChargeID(), creditAllocations.AsCreateInputs()) + realizations, err := s.Service.createCreditAllocations(ctx, s.Charge, creditAllocations.AsCreateInputs()) if err != nil { return fmt.Errorf("create credit allocations: %w", err) } @@ -222,12 +222,21 @@ func (s *CreditsOnlyStateMachine) DeleteCharge(ctx context.Context, policy meta. return fmt.Errorf("get currency calculator: %w", err) } + realizationIDs := lo.Map(s.Charge.State.CreditRealizations, func(realization creditrealization.Realization, _ int) string { + return realization.ID + }) + lineageSegmentsByRealization, err := s.Service.lineage.LoadActiveSegmentsByRealizationID(ctx, s.Charge.Namespace, realizationIDs) + if err != nil { + return fmt.Errorf("load active lineage segments: %w", err) + } + // Let's reverse the credit allocations corrections, err := s.Charge.State.CreditRealizations.CorrectAll(currencyCalculator, func(req creditrealization.CorrectionRequest) (creditrealization.CreateCorrectionInputs, error) { return s.Service.handler.OnCreditsOnlyUsageAccruedCorrection(ctx, flatfee.CreditsOnlyUsageAccruedCorrectionInput{ - Charge: s.Charge, - AllocateAt: clock.Now(), - Corrections: req, + Charge: s.Charge, + AllocateAt: clock.Now(), + Corrections: req, + LineageSegmentsByRealization: lineageSegmentsByRealization, }) }) if err != nil { @@ -235,7 +244,7 @@ func (s *CreditsOnlyStateMachine) DeleteCharge(ctx context.Context, policy meta. } if len(corrections) > 0 { - if _, err := s.Adapter.CreateCreditAllocations(ctx, s.Charge.GetChargeID(), corrections); err != nil { + if _, err := s.Service.createCreditAllocations(ctx, s.Charge, corrections); err != nil { return fmt.Errorf("create credit corrections: %w", err) } } diff --git a/openmeter/billing/charges/flatfee/service/invoice.go b/openmeter/billing/charges/flatfee/service/invoice.go index 1d2b511b39..bb361ebda3 100644 --- a/openmeter/billing/charges/flatfee/service/invoice.go +++ b/openmeter/billing/charges/flatfee/service/invoice.go @@ -44,7 +44,7 @@ func (s *service) PostLineAssignedToInvoice(ctx context.Context, charge flatfee. } // TODO: If we want we can bulk insert these into the database for better performance (for now it's fine) - realizations, err := s.adapter.CreateCreditAllocations(ctx, charge.GetChargeID(), creditAllocations.AsCreateInputs()) + realizations, err := s.createCreditAllocations(ctx, charge, creditAllocations.AsCreateInputs()) if err != nil { return nil, fmt.Errorf("creating credit realizations: %w", err) } diff --git a/openmeter/billing/charges/flatfee/service/lineage.go b/openmeter/billing/charges/flatfee/service/lineage.go new file mode 100644 index 0000000000..c86dccaf9a --- /dev/null +++ b/openmeter/billing/charges/flatfee/service/lineage.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" +) + +func (s *service) createCreditAllocations(ctx context.Context, charge flatfee.Charge, creditAllocations creditrealization.CreateInputs) (creditrealization.Realizations, error) { + realizations, err := s.adapter.CreateCreditAllocations(ctx, charge.GetChargeID(), creditAllocations) + if err != nil { + return creditrealization.Realizations{}, err + } + + if err := s.lineage.CreateInitialLineages(ctx, lineage.CreateInitialLineagesInput{ + Namespace: charge.Namespace, + ChargeID: charge.ID, + CustomerID: charge.Intent.CustomerID, + Currency: charge.Intent.Currency, + Realizations: realizations, + }); err != nil { + return creditrealization.Realizations{}, fmt.Errorf("create initial credit realization lineages: %w", err) + } + + if err := s.lineage.PersistCorrectionLineageSegments(ctx, lineage.PersistCorrectionLineageSegmentsInput{ + Namespace: charge.Namespace, + Realizations: realizations, + }); err != nil { + return creditrealization.Realizations{}, fmt.Errorf("persist correction lineage segments: %w", err) + } + + return realizations, nil +} diff --git a/openmeter/billing/charges/flatfee/service/service.go b/openmeter/billing/charges/flatfee/service/service.go index 6b44446265..0210816ec8 100644 --- a/openmeter/billing/charges/flatfee/service/service.go +++ b/openmeter/billing/charges/flatfee/service/service.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/pkg/framework/lockr" ) @@ -11,6 +12,7 @@ import ( type Config struct { Adapter flatfee.Adapter Handler flatfee.Handler + Lineage lineage.Service MetaAdapter meta.Adapter Locker *lockr.Locker } @@ -26,6 +28,10 @@ func (c Config) Validate() error { errs = append(errs, errors.New("handler cannot be null")) } + if c.Lineage == nil { + errs = append(errs, errors.New("lineage service cannot be null")) + } + if c.MetaAdapter == nil { errs = append(errs, errors.New("meta adapter cannot be null")) } @@ -45,6 +51,7 @@ func New(config Config) (flatfee.Service, error) { return &service{ adapter: config.Adapter, handler: config.Handler, + lineage: config.Lineage, metaAdapter: config.MetaAdapter, locker: config.Locker, }, nil @@ -53,6 +60,7 @@ func New(config Config) (flatfee.Service, error) { type service struct { adapter flatfee.Adapter handler flatfee.Handler + lineage lineage.Service metaAdapter meta.Adapter locker *lockr.Locker } diff --git a/openmeter/billing/charges/lineage/adapter/adapter.go b/openmeter/billing/charges/lineage/adapter/adapter.go new file mode 100644 index 0000000000..00fd24bd18 --- /dev/null +++ b/openmeter/billing/charges/lineage/adapter/adapter.go @@ -0,0 +1,62 @@ +package adapter + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/framework/transaction" +) + +type Config struct { + Client *entdb.Client +} + +func (c Config) Validate() error { + if c.Client == nil { + return errors.New("ent client is required") + } + + return nil +} + +func New(config Config) (lineage.Adapter, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &adapter{ + db: config.Client, + }, nil +} + +type adapter struct { + db *entdb.Client +} + +func (a *adapter) Tx(ctx context.Context) (context.Context, transaction.Driver, error) { + txCtx, rawConfig, eDriver, err := a.db.HijackTx(ctx, &sql.TxOptions{ + ReadOnly: false, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to hijack transaction: %w", err) + } + + return txCtx, entutils.NewTxDriver(eDriver, rawConfig), nil +} + +func (a *adapter) WithTx(ctx context.Context, tx *entutils.TxDriver) *adapter { + txDB := entdb.NewTxClientFromRawConfig(ctx, *tx.GetConfig()) + + return &adapter{ + db: txDB.Client(), + } +} + +func (a *adapter) Self() *adapter { + return a +} diff --git a/openmeter/billing/charges/lineage/adapter/lineage.go b/openmeter/billing/charges/lineage/adapter/lineage.go new file mode 100644 index 0000000000..b60f1fdbc5 --- /dev/null +++ b/openmeter/billing/charges/lineage/adapter/lineage.go @@ -0,0 +1,243 @@ +package adapter + +import ( + "context" + "fmt" + "time" + + "github.com/oklog/ulid/v2" + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/entutils" +) + +func LoadActiveSegmentsByRealizationID( + ctx context.Context, + db *entdb.Client, + namespace string, + realizationIDs []string, +) (lineage.ActiveSegmentsByRealizationID, error) { + repo := &adapter{db: db} + + return entutils.TransactingRepo(ctx, repo, func(ctx context.Context, tx *adapter) (lineage.ActiveSegmentsByRealizationID, error) { + if len(realizationIDs) == 0 { + return lineage.ActiveSegmentsByRealizationID{}, nil + } + + lineages, err := tx.db.CreditRealizationLineage.Query(). + Where( + creditrealizationlineage.Namespace(namespace), + creditrealizationlineage.RootRealizationIDIn(realizationIDs...), + ). + WithSegments(func(q *entdb.CreditRealizationLineageSegmentQuery) { + q.Where(creditrealizationlineagesegment.ClosedAtIsNil()). + Order(creditrealizationlineagesegment.ByCreatedAt()) + }). + All(ctx) + if err != nil { + return nil, err + } + + return lo.SliceToMap(lineages, func(entry *entdb.CreditRealizationLineage) (string, []lineage.Segment) { + return entry.RootRealizationID, lo.Map(entry.Edges.Segments, func(segment *entdb.CreditRealizationLineageSegment, _ int) lineage.Segment { + return lineage.Segment{ + ID: segment.ID, + LineageID: segment.LineageID, + Amount: segment.Amount, + State: segment.State, + BackingTransactionGroupID: segment.BackingTransactionGroupID, + } + }) + }), nil + }) +} + +func (a *adapter) LoadActiveSegmentsByRealizationID( + ctx context.Context, + namespace string, + realizationIDs []string, +) (lineage.ActiveSegmentsByRealizationID, error) { + return LoadActiveSegmentsByRealizationID(ctx, a.db, namespace, realizationIDs) +} + +func (a *adapter) CreateLineages(ctx context.Context, input lineage.CreateLineagesInput) error { + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + rootCreates := make([]*entdb.CreditRealizationLineageCreate, 0, len(input.Specs)) + segmentCreates := make([]*entdb.CreditRealizationLineageSegmentCreate, 0, len(input.Specs)) + + for _, spec := range input.Specs { + rootCreates = append(rootCreates, tx.db.CreditRealizationLineage.Create(). + SetID(spec.LineageID). + SetNamespace(input.Namespace). + SetChargeID(input.ChargeID). + SetRootRealizationID(spec.RootRealizationID). + SetCustomerID(input.CustomerID). + SetCurrency(input.Currency). + SetOriginKind(spec.OriginKind), + ) + segmentCreates = append(segmentCreates, tx.db.CreditRealizationLineageSegment.Create(). + SetLineageID(spec.LineageID). + SetAmount(spec.Amount). + SetState(spec.InitialState), + ) + } + + if _, err := tx.db.CreditRealizationLineage.CreateBulk(rootCreates...).Save(ctx); err != nil { + return fmt.Errorf("create credit realization lineages: %w", err) + } + if _, err := tx.db.CreditRealizationLineageSegment.CreateBulk(segmentCreates...).Save(ctx); err != nil { + return fmt.Errorf("create initial credit realization lineage segments: %w", err) + } + + return nil + }) +} + +func (a *adapter) LockCorrectionLineages(ctx context.Context, namespace string, realizationIDs []string) ([]lineage.Lineage, error) { + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) ([]lineage.Lineage, error) { + if _, err := entutils.GetDriverFromContext(ctx); err != nil { + return nil, fmt.Errorf("lock correction lineages must be called in a transaction: %w", err) + } + + lineages, err := tx.db.CreditRealizationLineage.Query(). + Where( + creditrealizationlineage.Namespace(namespace), + creditrealizationlineage.RootRealizationIDIn(realizationIDs...), + ). + WithSegments(func(q *entdb.CreditRealizationLineageSegmentQuery) { + q.Where(creditrealizationlineagesegment.ClosedAtIsNil()). + Order(creditrealizationlineagesegment.ByCreatedAt()) + }). + Order(creditrealizationlineage.ByCreatedAt()). + ForUpdate(). + All(ctx) + if err != nil { + return nil, err + } + + return lo.Map(lineages, mapLineage), nil + }) +} + +func (a *adapter) LockAdvanceLineagesForBackfill(ctx context.Context, namespace string, customerID string, currency currencyx.Code) ([]lineage.Lineage, error) { + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) ([]lineage.Lineage, error) { + if _, err := entutils.GetDriverFromContext(ctx); err != nil { + return nil, fmt.Errorf("lock advance lineages for backfill must be called in a transaction: %w", err) + } + + lineages, err := tx.db.CreditRealizationLineage.Query(). + Where( + creditrealizationlineage.Namespace(namespace), + creditrealizationlineage.CustomerIDEQ(customerID), + creditrealizationlineage.CurrencyEQ(currency), + creditrealizationlineage.HasSegmentsWith( + creditrealizationlineagesegment.ClosedAtIsNil(), + creditrealizationlineagesegment.StateEQ(creditrealization.LineageSegmentStateAdvanceUncovered), + ), + ). + Order(creditrealizationlineage.ByCreatedAt()). + ForUpdate(). + All(ctx) + if err != nil { + return nil, err + } + + return lo.Map(lineages, func(entry *entdb.CreditRealizationLineage, _ int) lineage.Lineage { + return lineage.Lineage{ + ID: entry.ID, + ChargeID: entry.ChargeID, + RootRealizationID: entry.RootRealizationID, + CustomerID: entry.CustomerID, + Currency: entry.Currency, + OriginKind: entry.OriginKind, + } + }), nil + }) +} + +func (a *adapter) ListActiveSegments(ctx context.Context, input lineage.ListActiveSegmentsInput) ([]lineage.Segment, error) { + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) ([]lineage.Segment, error) { + query := tx.db.CreditRealizationLineageSegment.Query(). + Where( + creditrealizationlineagesegment.ClosedAtIsNil(), + creditrealizationlineagesegment.LineageIDIn(input.LineageIDs...), + ). + Order(creditrealizationlineagesegment.ByCreatedAt()) + + if input.State != nil { + query = query.Where(creditrealizationlineagesegment.StateEQ(*input.State)) + } + + segments, err := query.All(ctx) + if err != nil { + return nil, err + } + + return lo.Map(segments, func(segment *entdb.CreditRealizationLineageSegment, _ int) lineage.Segment { + return mapSegment(segment) + }), nil + }) +} + +func (a *adapter) CloseSegment(ctx context.Context, segmentID string, closedAt time.Time) error { + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + if _, err := tx.db.CreditRealizationLineageSegment.UpdateOneID(segmentID). + SetClosedAt(closedAt). + Save(ctx); err != nil { + return err + } + + return nil + }) +} + +func (a *adapter) CreateSegment(ctx context.Context, input lineage.CreateSegmentInput) error { + if err := input.Validate(); err != nil { + return fmt.Errorf("create lineage segment: %w", err) + } + + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + create := tx.db.CreditRealizationLineageSegment.Create(). + SetID(ulid.Make().String()). + SetLineageID(input.LineageID). + SetAmount(input.Amount). + SetState(input.State) + + if input.BackingTransactionGroupID != nil { + create = create.SetBackingTransactionGroupID(*input.BackingTransactionGroupID) + } + + _, err := create.Save(ctx) + return err + }) +} + +func mapLineage(entry *entdb.CreditRealizationLineage, _ int) lineage.Lineage { + return lineage.Lineage{ + ID: entry.ID, + ChargeID: entry.ChargeID, + RootRealizationID: entry.RootRealizationID, + CustomerID: entry.CustomerID, + Currency: entry.Currency, + OriginKind: entry.OriginKind, + Segments: lo.Map(entry.Edges.Segments, func(segment *entdb.CreditRealizationLineageSegment, _ int) lineage.Segment { + return mapSegment(segment) + }), + } +} + +func mapSegment(segment *entdb.CreditRealizationLineageSegment) lineage.Segment { + return lineage.Segment{ + ID: segment.ID, + LineageID: segment.LineageID, + Amount: segment.Amount, + State: segment.State, + BackingTransactionGroupID: segment.BackingTransactionGroupID, + } +} diff --git a/openmeter/billing/charges/lineage/lineage.go b/openmeter/billing/charges/lineage/lineage.go new file mode 100644 index 0000000000..358f1526e5 --- /dev/null +++ b/openmeter/billing/charges/lineage/lineage.go @@ -0,0 +1,59 @@ +package lineage + +import ( + "errors" + "fmt" + "sort" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" +) + +func SortCorrectionPersistSegments(segments []Segment) []Segment { + sorted := append([]Segment(nil), segments...) + sort.SliceStable(sorted, func(i, j int) bool { + precedence := func(state creditrealization.LineageSegmentState) int { + switch state { + case creditrealization.LineageSegmentStateAdvanceBackfilled: + return 0 + case creditrealization.LineageSegmentStateAdvanceUncovered: + return 1 + case creditrealization.LineageSegmentStateRealCredit: + return 2 + default: + return 3 + } + } + + return precedence(sorted[i].State) < precedence(sorted[j].State) + }) + + return sorted +} + +func MinDecimal(a, b alpacadecimal.Decimal) alpacadecimal.Decimal { + if a.GreaterThan(b) { + return b + } + + return a +} + +func (s Segment) Validate() error { + var errs []error + + if !s.Amount.IsPositive() { + errs = append(errs, errors.New("amount must be positive")) + } + + if err := s.State.Validate(); err != nil { + errs = append(errs, fmt.Errorf("state: %w", err)) + } + + if s.State == creditrealization.LineageSegmentStateAdvanceBackfilled && (s.BackingTransactionGroupID == nil || *s.BackingTransactionGroupID == "") { + errs = append(errs, errors.New("backing transaction group id is required for advance_backfilled")) + } + + return errors.Join(errs...) +} diff --git a/openmeter/billing/charges/lineage/service.go b/openmeter/billing/charges/lineage/service.go new file mode 100644 index 0000000000..a296d9c5a4 --- /dev/null +++ b/openmeter/billing/charges/lineage/service.go @@ -0,0 +1,185 @@ +package lineage + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/entutils" +) + +type Service interface { + CreateInitialLineages(ctx context.Context, input CreateInitialLineagesInput) error + LoadActiveSegmentsByRealizationID(ctx context.Context, namespace string, realizationIDs []string) (ActiveSegmentsByRealizationID, error) + PersistCorrectionLineageSegments(ctx context.Context, input PersistCorrectionLineageSegmentsInput) error + BackfillAdvanceLineageSegments(ctx context.Context, input BackfillAdvanceLineageSegmentsInput) error +} + +type Adapter interface { + entutils.TxCreator + + CreateLineages(ctx context.Context, input CreateLineagesInput) error + LoadActiveSegmentsByRealizationID(ctx context.Context, namespace string, realizationIDs []string) (ActiveSegmentsByRealizationID, error) + LockCorrectionLineages(ctx context.Context, namespace string, realizationIDs []string) ([]Lineage, error) + LockAdvanceLineagesForBackfill(ctx context.Context, namespace string, customerID string, currency currencyx.Code) ([]Lineage, error) + ListActiveSegments(ctx context.Context, input ListActiveSegmentsInput) ([]Segment, error) + CloseSegment(ctx context.Context, segmentID string, closedAt time.Time) error + CreateSegment(ctx context.Context, input CreateSegmentInput) error +} + +type CreateInitialLineagesInput struct { + Namespace string + ChargeID string + CustomerID string + Currency currencyx.Code + Realizations creditrealization.Realizations +} + +func (i CreateInitialLineagesInput) Validate() error { + var errs []error + + if i.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + if i.ChargeID == "" { + errs = append(errs, errors.New("charge id is required")) + } + if i.CustomerID == "" { + errs = append(errs, errors.New("customer id is required")) + } + if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + if err := i.Realizations.Validate(); err != nil { + errs = append(errs, fmt.Errorf("realizations: %w", err)) + } + + return errors.Join(errs...) +} + +type PersistCorrectionLineageSegmentsInput struct { + Namespace string + Realizations creditrealization.Realizations +} + +func (i PersistCorrectionLineageSegmentsInput) Validate() error { + var errs []error + + if i.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + + for idx, realization := range i.Realizations { + if realization.Type != creditrealization.TypeCorrection { + continue + } + + if realization.CorrectsRealizationID == nil || *realization.CorrectsRealizationID == "" { + errs = append(errs, fmt.Errorf("realizations[%d]: corrects realization id is required for corrections", idx)) + } + } + + return errors.Join(errs...) +} + +type BackfillAdvanceLineageSegmentsInput struct { + Namespace string + CustomerID string + Currency currencyx.Code + Amount alpacadecimal.Decimal + BackingTransactionGroupID string +} + +func (i BackfillAdvanceLineageSegmentsInput) Validate() error { + var errs []error + + if i.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + if i.CustomerID == "" { + errs = append(errs, errors.New("customer id is required")) + } + if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + if !i.Amount.IsPositive() { + errs = append(errs, errors.New("amount must be positive")) + } + if i.BackingTransactionGroupID == "" { + errs = append(errs, errors.New("backing transaction group id is required")) + } + + return errors.Join(errs...) +} + +type CreateLineagesInput struct { + Namespace string + ChargeID string + CustomerID string + Currency currencyx.Code + Specs []creditrealization.InitialLineageSpec +} + +type ListActiveSegmentsInput struct { + LineageIDs []string + State *creditrealization.LineageSegmentState +} + +type CreateSegmentInput struct { + LineageID string + Amount alpacadecimal.Decimal + State creditrealization.LineageSegmentState + BackingTransactionGroupID *string +} + +func (i CreateSegmentInput) Validate() error { + var errs []error + + if i.LineageID == "" { + errs = append(errs, errors.New("lineage id is required")) + } + if !i.Amount.IsPositive() { + errs = append(errs, errors.New("amount must be positive")) + } + if err := i.State.Validate(); err != nil { + errs = append(errs, fmt.Errorf("state: %w", err)) + } + + switch i.State { + case creditrealization.LineageSegmentStateAdvanceBackfilled: + if i.BackingTransactionGroupID == nil || *i.BackingTransactionGroupID == "" { + errs = append(errs, errors.New("backing transaction group id is required for advance_backfilled segments")) + } + default: + if i.BackingTransactionGroupID != nil && *i.BackingTransactionGroupID == "" { + errs = append(errs, errors.New("backing transaction group id must not be empty when provided")) + } + } + + return errors.Join(errs...) +} + +type Lineage struct { + ID string + ChargeID string + RootRealizationID string + CustomerID string + Currency currencyx.Code + OriginKind creditrealization.LineageOriginKind + Segments []Segment +} + +type Segment struct { + ID string + LineageID string + Amount alpacadecimal.Decimal + State creditrealization.LineageSegmentState + BackingTransactionGroupID *string +} + +type ActiveSegmentsByRealizationID map[string][]Segment diff --git a/openmeter/billing/charges/lineage/service/service.go b/openmeter/billing/charges/lineage/service/service.go new file mode 100644 index 0000000000..c41f1e2671 --- /dev/null +++ b/openmeter/billing/charges/lineage/service/service.go @@ -0,0 +1,229 @@ +package service + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/framework/transaction" +) + +type Config struct { + Adapter lineage.Adapter +} + +func (c Config) Validate() error { + if c.Adapter == nil { + return errors.New("adapter cannot be null") + } + + return nil +} + +func New(config Config) (lineage.Service, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &service{ + adapter: config.Adapter, + }, nil +} + +type service struct { + adapter lineage.Adapter +} + +func (s *service) CreateInitialLineages(ctx context.Context, input lineage.CreateInitialLineagesInput) error { + if err := input.Validate(); err != nil { + return err + } + + return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error { + specs, err := creditrealization.InitialLineageSpecs(input.Realizations) + if err != nil { + return fmt.Errorf("build initial credit realization lineage specs: %w", err) + } + if len(specs) == 0 { + return nil + } + + return s.adapter.CreateLineages(ctx, lineage.CreateLineagesInput{ + Namespace: input.Namespace, + ChargeID: input.ChargeID, + CustomerID: input.CustomerID, + Currency: input.Currency, + Specs: specs, + }) + }) +} + +func (s *service) LoadActiveSegmentsByRealizationID(ctx context.Context, namespace string, realizationIDs []string) (lineage.ActiveSegmentsByRealizationID, error) { + if len(realizationIDs) == 0 { + return lineage.ActiveSegmentsByRealizationID{}, nil + } + + segmentsByRealizationID, err := s.adapter.LoadActiveSegmentsByRealizationID(ctx, namespace, realizationIDs) + if err != nil { + return nil, fmt.Errorf("load active lineage segments: %w", err) + } + + return segmentsByRealizationID, nil +} + +func (s *service) PersistCorrectionLineageSegments(ctx context.Context, input lineage.PersistCorrectionLineageSegmentsInput) error { + if err := input.Validate(); err != nil { + return err + } + + return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error { + correctionAmountsByRealizationID := make(map[string]alpacadecimal.Decimal, len(input.Realizations)) + correctionOrder := make([]string, 0) + + for _, realization := range input.Realizations { + if realization.Type != creditrealization.TypeCorrection || realization.CorrectsRealizationID == nil { + continue + } + + correctsRealizationID := *realization.CorrectsRealizationID + if _, ok := correctionAmountsByRealizationID[correctsRealizationID]; !ok { + correctionOrder = append(correctionOrder, correctsRealizationID) + } + + correctionAmountsByRealizationID[correctsRealizationID] = correctionAmountsByRealizationID[correctsRealizationID].Add(realization.Amount.Abs()) + } + + if len(correctionOrder) == 0 { + return nil + } + + lineages, err := s.adapter.LockCorrectionLineages(ctx, input.Namespace, correctionOrder) + if err != nil { + return fmt.Errorf("lock lineages for correction persistence: %w", err) + } + + lineagesByRealizationID := make(map[string]lineage.Lineage, len(lineages)) + for _, entry := range lineages { + lineagesByRealizationID[entry.RootRealizationID] = entry + } + + now := clock.Now().Truncate(time.Microsecond) + + for _, realizationID := range correctionOrder { + entry, ok := lineagesByRealizationID[realizationID] + if !ok { + continue + } + + remaining := correctionAmountsByRealizationID[realizationID] + for _, segment := range lineage.SortCorrectionPersistSegments(entry.Segments) { + if !remaining.IsPositive() { + break + } + + consumedAmount := lineage.MinDecimal(segment.Amount, remaining) + if !consumedAmount.IsPositive() { + continue + } + + if err := s.adapter.CloseSegment(ctx, segment.ID, now); err != nil { + return fmt.Errorf("close active lineage segment %s: %w", segment.ID, err) + } + + remainder := segment.Amount.Sub(consumedAmount) + if remainder.IsPositive() { + if err := s.adapter.CreateSegment(ctx, lineage.CreateSegmentInput{ + LineageID: segment.LineageID, + Amount: remainder, + State: segment.State, + BackingTransactionGroupID: segment.BackingTransactionGroupID, + }); err != nil { + return fmt.Errorf("create lineage segment remainder for %s: %w", segment.ID, err) + } + } + + remaining = remaining.Sub(consumedAmount) + } + + if remaining.IsPositive() { + return fmt.Errorf("correction amount %s exceeds active lineage coverage for realization %s", remaining.String(), realizationID) + } + } + + return nil + }) +} + +func (s *service) BackfillAdvanceLineageSegments(ctx context.Context, input lineage.BackfillAdvanceLineageSegmentsInput) error { + if err := input.Validate(); err != nil { + return err + } + + return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error { + lineages, err := s.adapter.LockAdvanceLineagesForBackfill(ctx, input.Namespace, input.CustomerID, input.Currency) + if err != nil { + return fmt.Errorf("lock advance lineages for backfill: %w", err) + } + if len(lineages) == 0 { + return nil + } + + lineageIDs := make([]string, 0, len(lineages)) + for _, entry := range lineages { + lineageIDs = append(lineageIDs, entry.ID) + } + + state := creditrealization.LineageSegmentStateAdvanceUncovered + segments, err := s.adapter.ListActiveSegments(ctx, lineage.ListActiveSegmentsInput{ + LineageIDs: lineageIDs, + State: &state, + }) + if err != nil { + return fmt.Errorf("query active uncovered advance lineage segments: %w", err) + } + + now := clock.Now().Truncate(time.Microsecond) + remaining := input.Amount + + for _, segment := range segments { + if !remaining.IsPositive() { + break + } + + coveredAmount := lineage.MinDecimal(segment.Amount, remaining) + if err := s.adapter.CloseSegment(ctx, segment.ID, now); err != nil { + return fmt.Errorf("close uncovered advance lineage segment %s: %w", segment.ID, err) + } + + remainder := segment.Amount.Sub(coveredAmount) + if remainder.IsPositive() { + if err := s.adapter.CreateSegment(ctx, lineage.CreateSegmentInput{ + LineageID: segment.LineageID, + Amount: remainder, + State: creditrealization.LineageSegmentStateAdvanceUncovered, + }); err != nil { + return fmt.Errorf("create uncovered advance lineage remainder for segment %s: %w", segment.ID, err) + } + } + + if err := s.adapter.CreateSegment(ctx, lineage.CreateSegmentInput{ + LineageID: segment.LineageID, + Amount: coveredAmount, + State: creditrealization.LineageSegmentStateAdvanceBackfilled, + BackingTransactionGroupID: &input.BackingTransactionGroupID, + }); err != nil { + return fmt.Errorf("create backfilled advance lineage segment for segment %s: %w", segment.ID, err) + } + + remaining = remaining.Sub(coveredAmount) + } + + return nil + }) +} diff --git a/openmeter/billing/charges/models/creditrealization/lineage.go b/openmeter/billing/charges/models/creditrealization/lineage.go new file mode 100644 index 0000000000..d2a7fdb871 --- /dev/null +++ b/openmeter/billing/charges/models/creditrealization/lineage.go @@ -0,0 +1,93 @@ +package creditrealization + +import ( + "fmt" + "slices" + + "github.com/openmeterio/openmeter/pkg/models" +) + +const AnnotationLineageOriginKind = "billing.credit_realization.lineage_origin_kind" + +type LineageOriginKind string + +const ( + LineageOriginKindRealCredit LineageOriginKind = "real_credit" + LineageOriginKindAdvance LineageOriginKind = "advance" +) + +func (k LineageOriginKind) Values() []string { + return []string{ + string(LineageOriginKindRealCredit), + string(LineageOriginKindAdvance), + } +} + +func (k LineageOriginKind) Validate() error { + if !slices.Contains(k.Values(), string(k)) { + return fmt.Errorf("invalid credit realization lineage origin kind: %s", k) + } + + return nil +} + +func LineageAnnotations(originKind LineageOriginKind) models.Annotations { + return models.Annotations{ + AnnotationLineageOriginKind: string(originKind), + } +} + +func LineageOriginKindFromAnnotations(annotations models.Annotations) (LineageOriginKind, error) { + originKind, ok := annotations.GetString(AnnotationLineageOriginKind) + if !ok { + return "", fmt.Errorf("missing credit realization lineage origin kind annotation") + } + + out := LineageOriginKind(originKind) + if err := out.Validate(); err != nil { + return "", err + } + + return out, nil +} + +type LineageSegmentState string + +const ( + // LineageSegmentStateRealCredit marks value that is still backed by the original + // real-credit source and has not passed through advance/backfill flows. + LineageSegmentStateRealCredit LineageSegmentState = "real_credit" + // LineageSegmentStateAdvanceUncovered marks value that was collected as advance-backed + // usage and is still not covered by a later credit purchase. + LineageSegmentStateAdvanceUncovered LineageSegmentState = "advance_uncovered" + // LineageSegmentStateAdvanceBackfilled marks value that was originally advance-backed + // usage but was later covered by a credit purchase. + LineageSegmentStateAdvanceBackfilled LineageSegmentState = "advance_backfilled" +) + +func (s LineageSegmentState) Values() []string { + return []string{ + string(LineageSegmentStateRealCredit), + string(LineageSegmentStateAdvanceUncovered), + string(LineageSegmentStateAdvanceBackfilled), + } +} + +func (s LineageSegmentState) Validate() error { + if !slices.Contains(s.Values(), string(s)) { + return fmt.Errorf("invalid credit realization lineage segment state: %s", s) + } + + return nil +} + +func InitialLineageSegmentState(originKind LineageOriginKind) LineageSegmentState { + switch originKind { + case LineageOriginKindRealCredit: + return LineageSegmentStateRealCredit + case LineageOriginKindAdvance: + return LineageSegmentStateAdvanceUncovered + default: + return "" + } +} diff --git a/openmeter/billing/charges/models/creditrealization/lineage_specs.go b/openmeter/billing/charges/models/creditrealization/lineage_specs.go new file mode 100644 index 0000000000..e0223a0183 --- /dev/null +++ b/openmeter/billing/charges/models/creditrealization/lineage_specs.go @@ -0,0 +1,46 @@ +package creditrealization + +import ( + "fmt" + + "github.com/alpacahq/alpacadecimal" + "github.com/oklog/ulid/v2" +) + +type InitialLineageSpec struct { + LineageID string + RootRealizationID string + OriginKind LineageOriginKind + InitialState LineageSegmentState + Amount alpacadecimal.Decimal +} + +func InitialLineageSpecs(realizations Realizations) ([]InitialLineageSpec, error) { + out := make([]InitialLineageSpec, 0, len(realizations)) + + for _, realization := range realizations { + if realization.Type != TypeAllocation { + continue + } + + originKind, err := LineageOriginKindFromAnnotations(realization.Annotations) + if err != nil { + continue + } + + initialState := InitialLineageSegmentState(originKind) + if err := initialState.Validate(); err != nil { + return nil, fmt.Errorf("realization %s initial lineage state: %w", realization.ID, err) + } + + out = append(out, InitialLineageSpec{ + LineageID: ulid.Make().String(), + RootRealizationID: realization.ID, + OriginKind: originKind, + InitialState: initialState, + Amount: realization.Amount, + }) + } + + return out, nil +} diff --git a/openmeter/billing/charges/service/base_test.go b/openmeter/billing/charges/service/base_test.go index 41e4e67e75..8bb2e1103f 100644 --- a/openmeter/billing/charges/service/base_test.go +++ b/openmeter/billing/charges/service/base_test.go @@ -17,6 +17,8 @@ import ( flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" flatfeelineengine "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/lineengine" flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" @@ -62,6 +64,16 @@ func (s *BaseSuite) SetupSuite() { }) s.NoError(err) + lineageAdapter, err := lineageadapter.New(lineageadapter.Config{ + Client: s.DBClient, + }) + s.NoError(err) + + lineageService, err := lineageservice.New(lineageservice.Config{ + Adapter: lineageAdapter, + }) + s.NoError(err) + flatFeeAdapter, err := flatfeeadapter.New(flatfeeadapter.Config{ Client: s.DBClient, Logger: slog.Default(), @@ -72,6 +84,7 @@ func (s *BaseSuite) SetupSuite() { flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ Adapter: flatFeeAdapter, Handler: s.FlatFeeTestHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, Locker: locker, }) @@ -96,6 +109,7 @@ func (s *BaseSuite) SetupSuite() { usageBasedService, err := usagebasedservice.New(usagebasedservice.Config{ Adapter: usageBasedAdapter, Handler: s.UsageBasedTestHandler, + Lineage: lineageService, Locker: locker, MetaAdapter: metaAdapter, CustomerOverrideService: s.BillingService, @@ -116,6 +130,7 @@ func (s *BaseSuite) SetupSuite() { creditPurchaseService, err := creditpurchaseservice.New(creditpurchaseservice.Config{ Adapter: creditPurchaseAdapter, Handler: s.CreditPurchaseTestHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, }) s.NoError(err) diff --git a/openmeter/billing/charges/service/lineage_test.go b/openmeter/billing/charges/service/lineage_test.go new file mode 100644 index 0000000000..1253658994 --- /dev/null +++ b/openmeter/billing/charges/service/lineage_test.go @@ -0,0 +1,413 @@ +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/invopop/gobl/currency" + "github.com/oklog/ulid/v2" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + lineage "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" + chargesmeta "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/customer" + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/datetime" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/framework/transaction" + "github.com/openmeterio/openmeter/pkg/timeutil" + billingtest "github.com/openmeterio/openmeter/test/billing" +) + +type CreditRealizationLineageTestSuite struct { + BaseSuite +} + +func TestCreditRealizationLineage(t *testing.T) { + suite.Run(t, new(CreditRealizationLineageTestSuite)) +} + +func (s *CreditRealizationLineageTestSuite) SetupSuite() { + s.BaseSuite.SetupSuite() +} + +func (s *CreditRealizationLineageTestSuite) TearDownTest() { + s.BaseSuite.TearDownTest() +} + +func (s *CreditRealizationLineageTestSuite) TestFlatFeeCreditOnlyAllocationCreatesInitialLineages() { + defer s.FlatFeeTestHandler.Reset() + + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-service-flatfee-credit-realization-lineage") + customInvoicing := s.SetupCustomInvoicing(ns) + cust := s.CreateTestCustomer(ns, "test-subject") + s.NotEmpty(cust.ID) + + _ = s.ProvisionBillingProfile(ctx, ns, customInvoicing.App.GetID(), + billingtest.WithProgressiveBilling(), + billingtest.WithCollectionInterval(datetime.MustParseDuration(s.T(), "PT1H")), + billingtest.WithManualApproval(), + ) + + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + + s.FlatFeeTestHandler.onCreditsOnlyUsageAccrued = func(ctx context.Context, input flatfee.OnCreditsOnlyUsageAccruedInput) (creditrealization.CreateAllocationInputs, error) { + return creditrealization.CreateAllocationInputs{ + { + ServicePeriod: input.Charge.Intent.ServicePeriod, + Amount: alpacadecimal.NewFromInt(20), + Annotations: creditrealization.LineageAnnotations(creditrealization.LineageOriginKindRealCredit), + LedgerTransaction: ledgertransaction.GroupReference{ + TransactionGroupID: ulid.Make().String(), + }, + }, + { + ServicePeriod: input.Charge.Intent.ServicePeriod, + Amount: alpacadecimal.NewFromInt(30), + Annotations: creditrealization.LineageAnnotations(creditrealization.LineageOriginKindAdvance), + LedgerTransaction: ledgertransaction.GroupReference{ + TransactionGroupID: ulid.Make().String(), + }, + }, + }, nil + } + + clock.FreezeTime(servicePeriod.From) + defer clock.UnFreeze() + + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: []charges.ChargeIntent{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: currencyx.Code(currency.USD), + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromInt(50), + PaymentTerm: productcatalog.InAdvancePaymentTerm, + }), + name: "flat-fee-lineage", + managedBy: billing.ManuallyManagedLine, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + chargeID, err := res[0].GetChargeID() + s.NoError(err) + + charge, err := s.mustGetChargeByID(chargeID).AsFlatFeeCharge() + s.NoError(err) + s.Len(charge.State.CreditRealizations, 2) + + lineages := s.mustListLineages(ns, realizationIDs(charge.State.CreditRealizations)) + s.Require().Len(lineages, 2) + + s.assertInitialLineage(lineages[charge.State.CreditRealizations[0].ID], chargeID.ID, charge.State.CreditRealizations[0].Amount, creditrealization.LineageOriginKindRealCredit, creditrealization.LineageSegmentStateRealCredit) + s.assertInitialLineage(lineages[charge.State.CreditRealizations[1].ID], chargeID.ID, charge.State.CreditRealizations[1].Amount, creditrealization.LineageOriginKindAdvance, creditrealization.LineageSegmentStateAdvanceUncovered) +} + +func (s *CreditRealizationLineageTestSuite) TestUsageBasedCreditOnlyAllocationCreatesInitialLineage() { + defer s.UsageBasedTestHandler.Reset() + + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-service-usagebased-credit-realization-lineage") + customInvoicing := s.SetupCustomInvoicing(ns) + cust := s.CreateTestCustomer(ns, "test-subject") + s.NotEmpty(cust.ID) + + _ = s.ProvisionBillingProfile(ctx, ns, customInvoicing.App.GetID(), + billingtest.WithProgressiveBilling(), + billingtest.WithCollectionInterval(datetime.MustParseDuration(s.T(), "P2D")), + billingtest.WithManualApproval(), + ) + + createAt := datetime.MustParseTimeInLocation(s.T(), "2025-12-01T00:00:00Z", time.UTC).AsTime() + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + firstCollectionAdvanceAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-01T12:00:00Z", time.UTC).AsTime() + + apiRequestsTotal := s.SetupApiRequestsTotalFeature(ctx, ns) + meterSlug := apiRequestsTotal.Feature.Key + + s.UsageBasedTestHandler.onCreditsOnlyUsageAccrued = func(ctx context.Context, input usagebased.CreditsOnlyUsageAccruedInput) (creditrealization.CreateAllocationInputs, error) { + return creditrealization.CreateAllocationInputs{ + { + ServicePeriod: input.Charge.Intent.ServicePeriod, + Amount: input.AmountToAllocate, + Annotations: creditrealization.LineageAnnotations(creditrealization.LineageOriginKindAdvance), + LedgerTransaction: ledgertransaction.GroupReference{ + TransactionGroupID: ulid.Make().String(), + }, + }, + }, nil + } + + clock.FreezeTime(createAt) + defer clock.UnFreeze() + + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: []charges.ChargeIntent{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: currencyx.Code(currency.USD), + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Amount: alpacadecimal.NewFromInt(1), + }), + name: "usage-based-lineage", + managedBy: billing.ManuallyManagedLine, + featureKey: meterSlug, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + usageCharge, err := res[0].AsUsageBasedCharge() + s.NoError(err) + + clock.FreezeTime(firstCollectionAdvanceAt) + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 3, + datetime.MustParseTimeInLocation(s.T(), "2026-01-15T00:00:00Z", time.UTC).AsTime(), + ) + + advancedCharge := s.mustAdvanceSingleUsageBasedCharge(ctx, cust.GetID()) + s.Require().NotNil(advancedCharge) + + charge, err := s.mustGetChargeByID(usageCharge.GetChargeID()).AsUsageBasedCharge() + s.NoError(err) + s.Require().NotNil(charge.State.CurrentRealizationRunID) + + currentRun, err := charge.Realizations.GetByID(*charge.State.CurrentRealizationRunID) + s.NoError(err) + s.Len(currentRun.CreditsAllocated, 1) + + lineages := s.mustListLineages(ns, realizationIDs(currentRun.CreditsAllocated)) + s.Require().Len(lineages, 1) + + s.assertInitialLineage(lineages[currentRun.CreditsAllocated[0].ID], usageCharge.ID, currentRun.CreditsAllocated[0].Amount, creditrealization.LineageOriginKindAdvance, creditrealization.LineageSegmentStateAdvanceUncovered) +} + +func (s *CreditRealizationLineageTestSuite) TestLockAdvanceLineagesForBackfillRequiresTransaction() { + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-service-lineage-lock-tx") + adapter, err := lineageadapter.New(lineageadapter.Config{ + Client: s.DBClient, + }) + s.Require().NoError(err) + + _, err = adapter.LockAdvanceLineagesForBackfill(ctx, ns, "customer-id", currencyx.Code(currency.USD)) + s.Error(err) + s.ErrorContains(err, "must be called in a transaction") +} + +func (s *CreditRealizationLineageTestSuite) TestLockAdvanceLineagesForBackfillWorksInTransaction() { + ctx, rawConfig, eDriver, err := s.DBClient.HijackTx(context.Background(), &sql.TxOptions{ReadOnly: false}) + s.Require().NoError(err) + + tx := entutils.NewTxDriver(eDriver, rawConfig) + ctx, err = transaction.SetDriverOnContext(ctx, tx) + s.Require().NoError(err) + s.T().Cleanup(func() { + _ = tx.Rollback() + }) + + ns := s.GetUniqueNamespace("charges-service-lineage-lock-in-tx") + adapter, err := lineageadapter.New(lineageadapter.Config{ + Client: s.DBClient, + }) + s.Require().NoError(err) + + lineages, err := adapter.LockAdvanceLineagesForBackfill(ctx, ns, "customer-id", currencyx.Code(currency.USD)) + s.NoError(err) + s.Empty(lineages) +} + +func (s *CreditRealizationLineageTestSuite) TestPersistCorrectionLineageSegmentsConsumesBackfilledBeforeUncovered() { + ctx := context.Background() + adapter, err := lineageadapter.New(lineageadapter.Config{ + Client: s.DBClient, + }) + s.Require().NoError(err) + + service, err := lineageservice.New(lineageservice.Config{ + Adapter: adapter, + }) + s.Require().NoError(err) + + ns := s.GetUniqueNamespace("charges-service-lineage-correction-persist") + backingTransactionGroupID := ulid.Make().String() + lineageID := ulid.Make().String() + chargeID := ulid.Make().String() + rootRealizationID := ulid.Make().String() + + _, err = s.DBClient.Charge.Create(). + SetID(chargeID). + SetNamespace(ns). + SetType(chargesmeta.ChargeTypeFlatFee). + Save(ctx) + s.Require().NoError(err) + + _, err = s.DBClient.CreditRealizationLineage.Create(). + SetID(lineageID). + SetNamespace(ns). + SetChargeID(chargeID). + SetRootRealizationID(rootRealizationID). + SetCustomerID(ulid.Make().String()). + SetCurrency(currencyx.Code(currency.USD)). + SetOriginKind(creditrealization.LineageOriginKindAdvance). + Save(ctx) + s.Require().NoError(err) + + _, err = s.DBClient.CreditRealizationLineageSegment.CreateBulk( + s.DBClient.CreditRealizationLineageSegment.Create(). + SetID(ulid.Make().String()). + SetLineageID(lineageID). + SetAmount(alpacadecimal.NewFromInt(20)). + SetState(creditrealization.LineageSegmentStateAdvanceBackfilled). + SetBackingTransactionGroupID(backingTransactionGroupID), + s.DBClient.CreditRealizationLineageSegment.Create(). + SetID(ulid.Make().String()). + SetLineageID(lineageID). + SetAmount(alpacadecimal.NewFromInt(30)). + SetState(creditrealization.LineageSegmentStateAdvanceUncovered), + ).Save(ctx) + s.Require().NoError(err) + + err = service.PersistCorrectionLineageSegments(ctx, lineage.PersistCorrectionLineageSegmentsInput{ + Namespace: ns, + Realizations: creditrealization.Realizations{ + { + CreateInput: creditrealization.CreateInput{ + Type: creditrealization.TypeCorrection, + Amount: alpacadecimal.NewFromInt(-15), + CorrectsRealizationID: lo.ToPtr(rootRealizationID), + }, + }, + }, + }) + s.Require().NoError(err) + + activeSegments, err := s.DBClient.CreditRealizationLineageSegment.Query(). + Where( + creditrealizationlineagesegment.LineageIDEQ(lineageID), + creditrealizationlineagesegment.ClosedAtIsNil(), + ). + Order(creditrealizationlineagesegment.ByCreatedAt()). + All(ctx) + s.Require().NoError(err) + s.Require().Len(activeSegments, 2) + + s.Equal(creditrealization.LineageSegmentStateAdvanceUncovered, activeSegments[0].State) + s.Equal(alpacadecimal.NewFromInt(30), activeSegments[0].Amount) + s.Nil(activeSegments[0].BackingTransactionGroupID) + + s.Equal(creditrealization.LineageSegmentStateAdvanceBackfilled, activeSegments[1].State) + s.Equal(alpacadecimal.NewFromInt(5), activeSegments[1].Amount) + s.Equal(backingTransactionGroupID, lo.FromPtr(activeSegments[1].BackingTransactionGroupID)) +} + +func (s *CreditRealizationLineageTestSuite) TestCreateSegmentRejectsInvalidInput() { + ctx := context.Background() + adapter, err := lineageadapter.New(lineageadapter.Config{ + Client: s.DBClient, + }) + s.Require().NoError(err) + + err = adapter.CreateSegment(ctx, lineage.CreateSegmentInput{ + LineageID: ulid.Make().String(), + Amount: alpacadecimal.NewFromInt(10), + State: creditrealization.LineageSegmentStateAdvanceBackfilled, + }) + s.Error(err) + s.ErrorContains(err, "backing transaction group id is required") +} + +func (s *CreditRealizationLineageTestSuite) mustAdvanceSingleUsageBasedCharge(ctx context.Context, customerID customer.CustomerID) *usagebased.Charge { + s.T().Helper() + + advancedCharges, err := s.Charges.AdvanceCharges(ctx, charges.AdvanceChargesInput{ + Customer: customerID, + }) + s.NoError(err) + + if len(advancedCharges) == 0 { + return nil + } + + s.Require().Len(advancedCharges, 1) + charge, err := advancedCharges[0].AsUsageBasedCharge() + s.NoError(err) + + return &charge +} + +func (s *CreditRealizationLineageTestSuite) mustListLineages(namespace string, realizationIDs []string) map[string]*entdb.CreditRealizationLineage { + s.T().Helper() + + lineages, err := s.DBClient.CreditRealizationLineage.Query(). + Where( + creditrealizationlineage.Namespace(namespace), + creditrealizationlineage.RootRealizationIDIn(realizationIDs...), + ). + WithSegments(). + All(s.T().Context()) + s.NoError(err) + + out := make(map[string]*entdb.CreditRealizationLineage, len(lineages)) + for _, lineage := range lineages { + out[lineage.RootRealizationID] = lineage + } + + return out +} + +func (s *CreditRealizationLineageTestSuite) assertInitialLineage(lineage *entdb.CreditRealizationLineage, chargeID string, amount alpacadecimal.Decimal, originKind creditrealization.LineageOriginKind, state creditrealization.LineageSegmentState) { + s.T().Helper() + + require.NotNil(s.T(), lineage) + s.Equal(chargeID, lineage.ChargeID) + s.Equal(originKind, lineage.OriginKind) + s.Require().Len(lineage.Edges.Segments, 1) + s.Equal(amount, lineage.Edges.Segments[0].Amount) + s.Equal(state, lineage.Edges.Segments[0].State) + s.Nil(lineage.Edges.Segments[0].ClosedAt) + s.Nil(lineage.Edges.Segments[0].BackingTransactionGroupID) +} + +func realizationIDs(realizations creditrealization.Realizations) []string { + return lo.Map(realizations, func(realization creditrealization.Realization, _ int) string { + return realization.ID + }) +} diff --git a/openmeter/billing/charges/testutils/service.go b/openmeter/billing/charges/testutils/service.go index cb38b82fec..b2c449c6dc 100644 --- a/openmeter/billing/charges/testutils/service.go +++ b/openmeter/billing/charges/testutils/service.go @@ -17,6 +17,8 @@ import ( flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" flatfeelineengine "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/lineengine" flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" chargesservice "github.com/openmeterio/openmeter/openmeter/billing/charges/service" "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" @@ -111,6 +113,20 @@ func NewServices(t testing.TB, config Config) (*Services, error) { return nil, fmt.Errorf("creating locker: %w", err) } + lineageAdapter, err := lineageadapter.New(lineageadapter.Config{ + Client: config.Client, + }) + if err != nil { + return nil, fmt.Errorf("creating lineage adapter: %w", err) + } + + lineageService, err := lineageservice.New(lineageservice.Config{ + Adapter: lineageAdapter, + }) + if err != nil { + return nil, fmt.Errorf("creating lineage service: %w", err) + } + flatFeeAdapter, err := flatfeeadapter.New(flatfeeadapter.Config{ Client: config.Client, Logger: logger, @@ -123,6 +139,7 @@ func NewServices(t testing.TB, config Config) (*Services, error) { flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ Adapter: flatFeeAdapter, Handler: config.FlatFeeHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, Locker: locker, }) @@ -154,6 +171,7 @@ func NewServices(t testing.TB, config Config) (*Services, error) { usageBasedService, err := usagebasedservice.New(usagebasedservice.Config{ Adapter: usageBasedAdapter, Handler: config.UsageBasedHandler, + Lineage: lineageService, Locker: locker, MetaAdapter: metaAdapter, CustomerOverrideService: config.BillingService, @@ -177,6 +195,7 @@ func NewServices(t testing.TB, config Config) (*Services, error) { creditPurchaseService, err := creditpurchaseservice.New(creditpurchaseservice.Config{ Adapter: creditPurchaseAdapter, Handler: config.CreditPurchaseHandler, + Lineage: lineageService, MetaAdapter: metaAdapter, }) if err != nil { diff --git a/openmeter/billing/charges/usagebased/adapter/charge.go b/openmeter/billing/charges/usagebased/adapter/charge.go index b8da9b1fde..af00decd5b 100644 --- a/openmeter/billing/charges/usagebased/adapter/charge.go +++ b/openmeter/billing/charges/usagebased/adapter/charge.go @@ -172,9 +172,14 @@ func (a *adapter) GetByIDs(ctx context.Context, input usagebased.GetByIDsInput) return nil, err } - return slicesx.MapWithErr(entitiesInOrder, func(entity *db.ChargeUsageBased) (usagebased.Charge, error) { + out, err := slicesx.MapWithErr(entitiesInOrder, func(entity *db.ChargeUsageBased) (usagebased.Charge, error) { return MapChargeFromDB(entity, input.Expands) }) + if err != nil { + return nil, err + } + + return out, nil }) } @@ -201,7 +206,12 @@ func (a *adapter) GetByID(ctx context.Context, input usagebased.GetByIDInput) (u return usagebased.Charge{}, fmt.Errorf("querying usage based charge [id=%s]: %w", input.ChargeID, err) } - return MapChargeFromDB(entity, input.Expands) + charge, err := MapChargeFromDB(entity, input.Expands) + if err != nil { + return usagebased.Charge{}, err + } + + return charge, nil }) } diff --git a/openmeter/billing/charges/usagebased/adapter/creditallocation.go b/openmeter/billing/charges/usagebased/adapter/creditallocation.go index 6aec361ff2..4cd2421955 100644 --- a/openmeter/billing/charges/usagebased/adapter/creditallocation.go +++ b/openmeter/billing/charges/usagebased/adapter/creditallocation.go @@ -38,8 +38,10 @@ func (a *adapter) CreateRunCreditRealization(ctx context.Context, runID usagebas return nil, err } - return lo.Map(dbEntities, func(entity *entdb.ChargeUsageBasedRunCreditAllocations, _ int) creditrealization.Realization { + realizations := lo.Map(dbEntities, func(entity *entdb.ChargeUsageBasedRunCreditAllocations, _ int) creditrealization.Realization { return creditrealization.MapFromDB(entity) - }), nil + }) + + return realizations, nil }) } diff --git a/openmeter/billing/charges/usagebased/handler.go b/openmeter/billing/charges/usagebased/handler.go index e37bd50fa2..87c0b56971 100644 --- a/openmeter/billing/charges/usagebased/handler.go +++ b/openmeter/billing/charges/usagebased/handler.go @@ -8,6 +8,7 @@ import ( "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" "github.com/openmeterio/openmeter/pkg/models" ) @@ -46,7 +47,8 @@ type CreditsOnlyUsageAccruedCorrectionInput struct { Run RealizationRun `json:"run"` AllocateAt time.Time `json:"allocateAt"` - Corrections creditrealization.CorrectionRequest `json:"corrections"` + Corrections creditrealization.CorrectionRequest `json:"corrections"` + LineageSegmentsByRealization lineage.ActiveSegmentsByRealizationID `json:"-"` } type Handler interface { diff --git a/openmeter/billing/charges/usagebased/service/creditsonly.go b/openmeter/billing/charges/usagebased/service/creditsonly.go index cf00a1073e..a166e8c8be 100644 --- a/openmeter/billing/charges/usagebased/service/creditsonly.go +++ b/openmeter/billing/charges/usagebased/service/creditsonly.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" @@ -126,12 +127,21 @@ func (s *CreditsOnlyStateMachine) DeleteCharge(ctx context.Context, policy meta. } for _, run := range s.Charge.Realizations { + realizationIDs := lo.Map(run.CreditsAllocated, func(realization creditrealization.Realization, _ int) string { + return realization.ID + }) + lineageSegmentsByRealization, err := s.Service.lineage.LoadActiveSegmentsByRealizationID(ctx, s.Charge.Namespace, realizationIDs) + if err != nil { + return fmt.Errorf("load active lineage segments for run %s: %w", run.ID.ID, err) + } + corrections, err := run.CreditsAllocated.CorrectAll(currencyCalculator, func(req creditrealization.CorrectionRequest) (creditrealization.CreateCorrectionInputs, error) { return s.Service.handler.OnCreditsOnlyUsageAccruedCorrection(ctx, usagebased.CreditsOnlyUsageAccruedCorrectionInput{ - Charge: s.Charge, - Run: run, - AllocateAt: clock.Now(), - Corrections: req, + Charge: s.Charge, + Run: run, + AllocateAt: clock.Now(), + Corrections: req, + LineageSegmentsByRealization: lineageSegmentsByRealization, }) }) if err != nil { @@ -139,7 +149,7 @@ func (s *CreditsOnlyStateMachine) DeleteCharge(ctx context.Context, policy meta. } if len(corrections) > 0 { - if _, err := s.Adapter.CreateRunCreditRealization(ctx, run.ID, corrections); err != nil { + if _, err := s.Service.createRunCreditRealizations(ctx, s.Charge, run.ID, corrections); err != nil { return fmt.Errorf("create credit corrections for run %s: %w", run.ID.ID, err) } } @@ -237,7 +247,7 @@ func (s *CreditsOnlyStateMachine) StartFinalRealizationRun(ctx context.Context) } if len(creditAllocations) > 0 { - creditRealizations, err = s.Adapter.CreateRunCreditRealization(ctx, currentRun.ID, creditAllocations) + creditRealizations, err = s.Service.createRunCreditRealizations(ctx, updatedCharge, currentRun.ID, creditAllocations) if err != nil { return fmt.Errorf("create credit allocations: %w", err) } @@ -312,20 +322,29 @@ func (s *CreditsOnlyStateMachine) FinalizeRealizationRun(ctx context.Context) er } if len(creditAllocations) > 0 { - if _, err := s.Adapter.CreateRunCreditRealization(ctx, currentRun.ID, creditAllocations); err != nil { + if _, err := s.Service.createRunCreditRealizations(ctx, s.Charge, currentRun.ID, creditAllocations); err != nil { return fmt.Errorf("create credit allocations: %w", err) } } case additionalAmount.IsNegative(): + realizationIDs := lo.Map(currentRun.CreditsAllocated, func(realization creditrealization.Realization, _ int) string { + return realization.ID + }) + lineageSegmentsByRealization, err := s.Service.lineage.LoadActiveSegmentsByRealizationID(ctx, s.Charge.Namespace, realizationIDs) + if err != nil { + return fmt.Errorf("load active lineage segments for current run: %w", err) + } + corrections, err := currentRun.CreditsAllocated.Correct( additionalAmount, s.CurrencyCalculator, func(req creditrealization.CorrectionRequest) (creditrealization.CreateCorrectionInputs, error) { return s.Service.handler.OnCreditsOnlyUsageAccruedCorrection(ctx, usagebased.CreditsOnlyUsageAccruedCorrectionInput{ - Charge: s.Charge, - Run: currentRun, - AllocateAt: storedAtOffset, - Corrections: req, + Charge: s.Charge, + Run: currentRun, + AllocateAt: storedAtOffset, + Corrections: req, + LineageSegmentsByRealization: lineageSegmentsByRealization, }) }, ) @@ -334,7 +353,7 @@ func (s *CreditsOnlyStateMachine) FinalizeRealizationRun(ctx context.Context) er } if len(corrections) > 0 { - if _, err := s.Adapter.CreateRunCreditRealization(ctx, currentRun.ID, corrections); err != nil { + if _, err := s.Service.createRunCreditRealizations(ctx, s.Charge, currentRun.ID, corrections); err != nil { return fmt.Errorf("create credit corrections: %w", err) } } diff --git a/openmeter/billing/charges/usagebased/service/lineage.go b/openmeter/billing/charges/usagebased/service/lineage.go new file mode 100644 index 0000000000..2e5222ba20 --- /dev/null +++ b/openmeter/billing/charges/usagebased/service/lineage.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" +) + +func (s *service) createRunCreditRealizations(ctx context.Context, charge usagebased.Charge, runID usagebased.RealizationRunID, creditAllocations creditrealization.CreateInputs) (creditrealization.Realizations, error) { + realizations, err := s.adapter.CreateRunCreditRealization(ctx, runID, creditAllocations) + if err != nil { + return nil, err + } + + if err := s.lineage.CreateInitialLineages(ctx, lineage.CreateInitialLineagesInput{ + Namespace: charge.Namespace, + ChargeID: charge.ID, + CustomerID: charge.Intent.CustomerID, + Currency: charge.Intent.Currency, + Realizations: realizations, + }); err != nil { + return nil, fmt.Errorf("create initial credit realization lineages: %w", err) + } + + if err := s.lineage.PersistCorrectionLineageSegments(ctx, lineage.PersistCorrectionLineageSegmentsInput{ + Namespace: charge.Namespace, + Realizations: realizations, + }); err != nil { + return nil, fmt.Errorf("persist correction lineage segments: %w", err) + } + + return realizations, nil +} diff --git a/openmeter/billing/charges/usagebased/service/service.go b/openmeter/billing/charges/usagebased/service/service.go index 584c530d81..bab4aeecd5 100644 --- a/openmeter/billing/charges/usagebased/service/service.go +++ b/openmeter/billing/charges/usagebased/service/service.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" "github.com/openmeterio/openmeter/openmeter/billing/rating" @@ -15,6 +16,7 @@ import ( type Config struct { Adapter usagebased.Adapter Handler usagebased.Handler + Lineage lineage.Service Locker *lockr.Locker MetaAdapter meta.Adapter CustomerOverrideService billing.CustomerOverrideService @@ -35,6 +37,10 @@ func (c Config) Validate() error { errs = append(errs, errors.New("handler cannot be null")) } + if c.Lineage == nil { + errs = append(errs, errors.New("lineage service cannot be null")) + } + if c.Locker == nil { errs = append(errs, errors.New("locker cannot be null")) } @@ -70,6 +76,7 @@ func New(config Config) (usagebased.Service, error) { return &service{ adapter: config.Adapter, handler: config.Handler, + lineage: config.Lineage, locker: config.Locker, metaAdapter: config.MetaAdapter, customerOverrideService: config.CustomerOverrideService, @@ -83,6 +90,7 @@ type service struct { streamingConnector streaming.Connector adapter usagebased.Adapter handler usagebased.Handler + lineage lineage.Service locker *lockr.Locker metaAdapter meta.Adapter customerOverrideService billing.CustomerOverrideService diff --git a/openmeter/ent/db/charge.go b/openmeter/ent/db/charge.go index 1b93568e27..d8f9db58f9 100644 --- a/openmeter/ent/db/charge.go +++ b/openmeter/ent/db/charge.go @@ -55,9 +55,11 @@ type ChargeEdges struct { BillingInvoiceLines []*BillingInvoiceLine `json:"billing_invoice_lines,omitempty"` // BillingSplitLineGroups holds the value of the billing_split_line_groups edge. BillingSplitLineGroups []*BillingInvoiceSplitLineGroup `json:"billing_split_line_groups,omitempty"` + // CreditRealizationLineages holds the value of the credit_realization_lineages edge. + CreditRealizationLineages []*CreditRealizationLineage `json:"credit_realization_lineages,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [5]bool + loadedTypes [6]bool } // FlatFeeOrErr returns the FlatFee value or an error if the edge @@ -111,6 +113,15 @@ func (e ChargeEdges) BillingSplitLineGroupsOrErr() ([]*BillingInvoiceSplitLineGr return nil, &NotLoadedError{edge: "billing_split_line_groups"} } +// CreditRealizationLineagesOrErr returns the CreditRealizationLineages value or an error if the edge +// was not loaded in eager-loading. +func (e ChargeEdges) CreditRealizationLineagesOrErr() ([]*CreditRealizationLineage, error) { + if e.loadedTypes[5] { + return e.CreditRealizationLineages, nil + } + return nil, &NotLoadedError{edge: "credit_realization_lineages"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*Charge) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -232,6 +243,11 @@ func (_m *Charge) QueryBillingSplitLineGroups() *BillingInvoiceSplitLineGroupQue return NewChargeClient(_m.config).QueryBillingSplitLineGroups(_m) } +// QueryCreditRealizationLineages queries the "credit_realization_lineages" edge of the Charge entity. +func (_m *Charge) QueryCreditRealizationLineages() *CreditRealizationLineageQuery { + return NewChargeClient(_m.config).QueryCreditRealizationLineages(_m) +} + // Update returns a builder for updating this Charge. // Note that you need to call Charge.Unwrap() before calling this method if this Charge // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/openmeter/ent/db/charge/charge.go b/openmeter/ent/db/charge/charge.go index 3aa41386cc..d17e43e076 100644 --- a/openmeter/ent/db/charge/charge.go +++ b/openmeter/ent/db/charge/charge.go @@ -40,6 +40,8 @@ const ( EdgeBillingInvoiceLines = "billing_invoice_lines" // EdgeBillingSplitLineGroups holds the string denoting the billing_split_line_groups edge name in mutations. EdgeBillingSplitLineGroups = "billing_split_line_groups" + // EdgeCreditRealizationLineages holds the string denoting the credit_realization_lineages edge name in mutations. + EdgeCreditRealizationLineages = "credit_realization_lineages" // Table holds the table name of the charge in the database. Table = "charges" // FlatFeeTable is the table that holds the flat_fee relation/edge. @@ -77,6 +79,13 @@ const ( BillingSplitLineGroupsInverseTable = "billing_invoice_split_line_groups" // BillingSplitLineGroupsColumn is the table column denoting the billing_split_line_groups relation/edge. BillingSplitLineGroupsColumn = "charge_id" + // CreditRealizationLineagesTable is the table that holds the credit_realization_lineages relation/edge. + CreditRealizationLineagesTable = "credit_realization_lineages" + // CreditRealizationLineagesInverseTable is the table name for the CreditRealizationLineage entity. + // It exists in this package in order to avoid circular dependency with the "creditrealizationlineage" package. + CreditRealizationLineagesInverseTable = "credit_realization_lineages" + // CreditRealizationLineagesColumn is the table column denoting the credit_realization_lineages relation/edge. + CreditRealizationLineagesColumn = "charge_id" ) // Columns holds all SQL columns for charge fields. @@ -209,6 +218,20 @@ func ByBillingSplitLineGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderO sqlgraph.OrderByNeighborTerms(s, newBillingSplitLineGroupsStep(), append([]sql.OrderTerm{term}, terms...)...) } } + +// ByCreditRealizationLineagesCount orders the results by credit_realization_lineages count. +func ByCreditRealizationLineagesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newCreditRealizationLineagesStep(), opts...) + } +} + +// ByCreditRealizationLineages orders the results by credit_realization_lineages terms. +func ByCreditRealizationLineages(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newCreditRealizationLineagesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newFlatFeeStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -244,3 +267,10 @@ func newBillingSplitLineGroupsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, BillingSplitLineGroupsTable, BillingSplitLineGroupsColumn), ) } +func newCreditRealizationLineagesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(CreditRealizationLineagesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, CreditRealizationLineagesTable, CreditRealizationLineagesColumn), + ) +} diff --git a/openmeter/ent/db/charge/where.go b/openmeter/ent/db/charge/where.go index 3cd0b12f64..bef7849baa 100644 --- a/openmeter/ent/db/charge/where.go +++ b/openmeter/ent/db/charge/where.go @@ -761,6 +761,29 @@ func HasBillingSplitLineGroupsWith(preds ...predicate.BillingInvoiceSplitLineGro }) } +// HasCreditRealizationLineages applies the HasEdge predicate on the "credit_realization_lineages" edge. +func HasCreditRealizationLineages() predicate.Charge { + return predicate.Charge(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, CreditRealizationLineagesTable, CreditRealizationLineagesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasCreditRealizationLineagesWith applies the HasEdge predicate on the "credit_realization_lineages" edge with a given conditions (other predicates). +func HasCreditRealizationLineagesWith(preds ...predicate.CreditRealizationLineage) predicate.Charge { + return predicate.Charge(func(s *sql.Selector) { + step := newCreditRealizationLineagesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Charge) predicate.Charge { return predicate.Charge(sql.AndPredicates(predicates...)) diff --git a/openmeter/ent/db/charge_create.go b/openmeter/ent/db/charge_create.go index fdffb88a25..a698ec5666 100644 --- a/openmeter/ent/db/charge_create.go +++ b/openmeter/ent/db/charge_create.go @@ -19,6 +19,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargecreditpurchase" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeflatfee" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebased" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" ) // ChargeCreate is the builder for creating a Charge entity. @@ -226,6 +227,21 @@ func (_c *ChargeCreate) AddBillingSplitLineGroups(v ...*BillingInvoiceSplitLineG return _c.AddBillingSplitLineGroupIDs(ids...) } +// AddCreditRealizationLineageIDs adds the "credit_realization_lineages" edge to the CreditRealizationLineage entity by IDs. +func (_c *ChargeCreate) AddCreditRealizationLineageIDs(ids ...string) *ChargeCreate { + _c.mutation.AddCreditRealizationLineageIDs(ids...) + return _c +} + +// AddCreditRealizationLineages adds the "credit_realization_lineages" edges to the CreditRealizationLineage entity. +func (_c *ChargeCreate) AddCreditRealizationLineages(v ...*CreditRealizationLineage) *ChargeCreate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddCreditRealizationLineageIDs(ids...) +} + // Mutation returns the ChargeMutation object of the builder. func (_c *ChargeCreate) Mutation() *ChargeMutation { return _c.mutation @@ -436,6 +452,22 @@ func (_c *ChargeCreate) createSpec() (*Charge, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.CreditRealizationLineagesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/openmeter/ent/db/charge_query.go b/openmeter/ent/db/charge_query.go index d89997b56a..0596d13840 100644 --- a/openmeter/ent/db/charge_query.go +++ b/openmeter/ent/db/charge_query.go @@ -19,22 +19,24 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargecreditpurchase" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeflatfee" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebased" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" ) // ChargeQuery is the builder for querying Charge entities. type ChargeQuery struct { config - ctx *QueryContext - order []charge.OrderOption - inters []Interceptor - predicates []predicate.Charge - withFlatFee *ChargeFlatFeeQuery - withCreditPurchase *ChargeCreditPurchaseQuery - withUsageBased *ChargeUsageBasedQuery - withBillingInvoiceLines *BillingInvoiceLineQuery - withBillingSplitLineGroups *BillingInvoiceSplitLineGroupQuery - modifiers []func(*sql.Selector) + ctx *QueryContext + order []charge.OrderOption + inters []Interceptor + predicates []predicate.Charge + withFlatFee *ChargeFlatFeeQuery + withCreditPurchase *ChargeCreditPurchaseQuery + withUsageBased *ChargeUsageBasedQuery + withBillingInvoiceLines *BillingInvoiceLineQuery + withBillingSplitLineGroups *BillingInvoiceSplitLineGroupQuery + withCreditRealizationLineages *CreditRealizationLineageQuery + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -181,6 +183,28 @@ func (_q *ChargeQuery) QueryBillingSplitLineGroups() *BillingInvoiceSplitLineGro return query } +// QueryCreditRealizationLineages chains the current query on the "credit_realization_lineages" edge. +func (_q *ChargeQuery) QueryCreditRealizationLineages() *CreditRealizationLineageQuery { + query := (&CreditRealizationLineageClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(charge.Table, charge.FieldID, selector), + sqlgraph.To(creditrealizationlineage.Table, creditrealizationlineage.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, charge.CreditRealizationLineagesTable, charge.CreditRealizationLineagesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Charge entity from the query. // Returns a *NotFoundError when no Charge was found. func (_q *ChargeQuery) First(ctx context.Context) (*Charge, error) { @@ -368,16 +392,17 @@ func (_q *ChargeQuery) Clone() *ChargeQuery { return nil } return &ChargeQuery{ - config: _q.config, - ctx: _q.ctx.Clone(), - order: append([]charge.OrderOption{}, _q.order...), - inters: append([]Interceptor{}, _q.inters...), - predicates: append([]predicate.Charge{}, _q.predicates...), - withFlatFee: _q.withFlatFee.Clone(), - withCreditPurchase: _q.withCreditPurchase.Clone(), - withUsageBased: _q.withUsageBased.Clone(), - withBillingInvoiceLines: _q.withBillingInvoiceLines.Clone(), - withBillingSplitLineGroups: _q.withBillingSplitLineGroups.Clone(), + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]charge.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.Charge{}, _q.predicates...), + withFlatFee: _q.withFlatFee.Clone(), + withCreditPurchase: _q.withCreditPurchase.Clone(), + withUsageBased: _q.withUsageBased.Clone(), + withBillingInvoiceLines: _q.withBillingInvoiceLines.Clone(), + withBillingSplitLineGroups: _q.withBillingSplitLineGroups.Clone(), + withCreditRealizationLineages: _q.withCreditRealizationLineages.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, @@ -439,6 +464,17 @@ func (_q *ChargeQuery) WithBillingSplitLineGroups(opts ...func(*BillingInvoiceSp return _q } +// WithCreditRealizationLineages tells the query-builder to eager-load the nodes that are connected to +// the "credit_realization_lineages" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChargeQuery) WithCreditRealizationLineages(opts ...func(*CreditRealizationLineageQuery)) *ChargeQuery { + query := (&CreditRealizationLineageClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withCreditRealizationLineages = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -517,12 +553,13 @@ func (_q *ChargeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Charg var ( nodes = []*Charge{} _spec = _q.querySpec() - loadedTypes = [5]bool{ + loadedTypes = [6]bool{ _q.withFlatFee != nil, _q.withCreditPurchase != nil, _q.withUsageBased != nil, _q.withBillingInvoiceLines != nil, _q.withBillingSplitLineGroups != nil, + _q.withCreditRealizationLineages != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -582,6 +619,15 @@ func (_q *ChargeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Charg return nil, err } } + if query := _q.withCreditRealizationLineages; query != nil { + if err := _q.loadCreditRealizationLineages(ctx, query, nodes, + func(n *Charge) { n.Edges.CreditRealizationLineages = []*CreditRealizationLineage{} }, + func(n *Charge, e *CreditRealizationLineage) { + n.Edges.CreditRealizationLineages = append(n.Edges.CreditRealizationLineages, e) + }); err != nil { + return nil, err + } + } return nodes, nil } @@ -748,6 +794,36 @@ func (_q *ChargeQuery) loadBillingSplitLineGroups(ctx context.Context, query *Bi } return nil } +func (_q *ChargeQuery) loadCreditRealizationLineages(ctx context.Context, query *CreditRealizationLineageQuery, nodes []*Charge, init func(*Charge), assign func(*Charge, *CreditRealizationLineage)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[string]*Charge) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(creditrealizationlineage.FieldChargeID) + } + query.Where(predicate.CreditRealizationLineage(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(charge.CreditRealizationLineagesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ChargeID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "charge_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *ChargeQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() diff --git a/openmeter/ent/db/charge_update.go b/openmeter/ent/db/charge_update.go index 600d0cbe6f..b3efd1752d 100644 --- a/openmeter/ent/db/charge_update.go +++ b/openmeter/ent/db/charge_update.go @@ -14,6 +14,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoicesplitlinegroup" "github.com/openmeterio/openmeter/openmeter/ent/db/charge" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" ) @@ -80,6 +81,21 @@ func (_u *ChargeUpdate) AddBillingSplitLineGroups(v ...*BillingInvoiceSplitLineG return _u.AddBillingSplitLineGroupIDs(ids...) } +// AddCreditRealizationLineageIDs adds the "credit_realization_lineages" edge to the CreditRealizationLineage entity by IDs. +func (_u *ChargeUpdate) AddCreditRealizationLineageIDs(ids ...string) *ChargeUpdate { + _u.mutation.AddCreditRealizationLineageIDs(ids...) + return _u +} + +// AddCreditRealizationLineages adds the "credit_realization_lineages" edges to the CreditRealizationLineage entity. +func (_u *ChargeUpdate) AddCreditRealizationLineages(v ...*CreditRealizationLineage) *ChargeUpdate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddCreditRealizationLineageIDs(ids...) +} + // Mutation returns the ChargeMutation object of the builder. func (_u *ChargeUpdate) Mutation() *ChargeMutation { return _u.mutation @@ -127,6 +143,27 @@ func (_u *ChargeUpdate) RemoveBillingSplitLineGroups(v ...*BillingInvoiceSplitLi return _u.RemoveBillingSplitLineGroupIDs(ids...) } +// ClearCreditRealizationLineages clears all "credit_realization_lineages" edges to the CreditRealizationLineage entity. +func (_u *ChargeUpdate) ClearCreditRealizationLineages() *ChargeUpdate { + _u.mutation.ClearCreditRealizationLineages() + return _u +} + +// RemoveCreditRealizationLineageIDs removes the "credit_realization_lineages" edge to CreditRealizationLineage entities by IDs. +func (_u *ChargeUpdate) RemoveCreditRealizationLineageIDs(ids ...string) *ChargeUpdate { + _u.mutation.RemoveCreditRealizationLineageIDs(ids...) + return _u +} + +// RemoveCreditRealizationLineages removes "credit_realization_lineages" edges to CreditRealizationLineage entities. +func (_u *ChargeUpdate) RemoveCreditRealizationLineages(v ...*CreditRealizationLineage) *ChargeUpdate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveCreditRealizationLineageIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *ChargeUpdate) Save(ctx context.Context) (int, error) { return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) @@ -262,6 +299,51 @@ func (_u *ChargeUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.CreditRealizationLineagesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedCreditRealizationLineagesIDs(); len(nodes) > 0 && !_u.mutation.CreditRealizationLineagesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.CreditRealizationLineagesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{charge.Label} @@ -332,6 +414,21 @@ func (_u *ChargeUpdateOne) AddBillingSplitLineGroups(v ...*BillingInvoiceSplitLi return _u.AddBillingSplitLineGroupIDs(ids...) } +// AddCreditRealizationLineageIDs adds the "credit_realization_lineages" edge to the CreditRealizationLineage entity by IDs. +func (_u *ChargeUpdateOne) AddCreditRealizationLineageIDs(ids ...string) *ChargeUpdateOne { + _u.mutation.AddCreditRealizationLineageIDs(ids...) + return _u +} + +// AddCreditRealizationLineages adds the "credit_realization_lineages" edges to the CreditRealizationLineage entity. +func (_u *ChargeUpdateOne) AddCreditRealizationLineages(v ...*CreditRealizationLineage) *ChargeUpdateOne { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddCreditRealizationLineageIDs(ids...) +} + // Mutation returns the ChargeMutation object of the builder. func (_u *ChargeUpdateOne) Mutation() *ChargeMutation { return _u.mutation @@ -379,6 +476,27 @@ func (_u *ChargeUpdateOne) RemoveBillingSplitLineGroups(v ...*BillingInvoiceSpli return _u.RemoveBillingSplitLineGroupIDs(ids...) } +// ClearCreditRealizationLineages clears all "credit_realization_lineages" edges to the CreditRealizationLineage entity. +func (_u *ChargeUpdateOne) ClearCreditRealizationLineages() *ChargeUpdateOne { + _u.mutation.ClearCreditRealizationLineages() + return _u +} + +// RemoveCreditRealizationLineageIDs removes the "credit_realization_lineages" edge to CreditRealizationLineage entities by IDs. +func (_u *ChargeUpdateOne) RemoveCreditRealizationLineageIDs(ids ...string) *ChargeUpdateOne { + _u.mutation.RemoveCreditRealizationLineageIDs(ids...) + return _u +} + +// RemoveCreditRealizationLineages removes "credit_realization_lineages" edges to CreditRealizationLineage entities. +func (_u *ChargeUpdateOne) RemoveCreditRealizationLineages(v ...*CreditRealizationLineage) *ChargeUpdateOne { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveCreditRealizationLineageIDs(ids...) +} + // Where appends a list predicates to the ChargeUpdate builder. func (_u *ChargeUpdateOne) Where(ps ...predicate.Charge) *ChargeUpdateOne { _u.mutation.Where(ps...) @@ -544,6 +662,51 @@ func (_u *ChargeUpdateOne) sqlSave(ctx context.Context) (_node *Charge, err erro } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.CreditRealizationLineagesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedCreditRealizationLineagesIDs(); len(nodes) > 0 && !_u.mutation.CreditRealizationLineagesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.CreditRealizationLineagesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: charge.CreditRealizationLineagesTable, + Columns: []string{charge.CreditRealizationLineagesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Charge{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/openmeter/ent/db/client.go b/openmeter/ent/db/client.go index ff1686d03f..58207f6390 100644 --- a/openmeter/ent/db/client.go +++ b/openmeter/ent/db/client.go @@ -53,6 +53,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruninvoicedusage" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedrunpayment" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruns" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" "github.com/openmeterio/openmeter/openmeter/ent/db/currencycostbasis" "github.com/openmeterio/openmeter/openmeter/ent/db/customcurrency" "github.com/openmeterio/openmeter/openmeter/ent/db/customer" @@ -173,6 +175,10 @@ type Client struct { ChargeUsageBasedRuns *ChargeUsageBasedRunsClient // ChargesSearchV1 is the client for interacting with the ChargesSearchV1 builders. ChargesSearchV1 *ChargesSearchV1Client + // CreditRealizationLineage is the client for interacting with the CreditRealizationLineage builders. + CreditRealizationLineage *CreditRealizationLineageClient + // CreditRealizationLineageSegment is the client for interacting with the CreditRealizationLineageSegment builders. + CreditRealizationLineageSegment *CreditRealizationLineageSegmentClient // CurrencyCostBasis is the client for interacting with the CurrencyCostBasis builders. CurrencyCostBasis *CurrencyCostBasisClient // CustomCurrency is the client for interacting with the CustomCurrency builders. @@ -289,6 +295,8 @@ func (c *Client) init() { c.ChargeUsageBasedRunPayment = NewChargeUsageBasedRunPaymentClient(c.config) c.ChargeUsageBasedRuns = NewChargeUsageBasedRunsClient(c.config) c.ChargesSearchV1 = NewChargesSearchV1Client(c.config) + c.CreditRealizationLineage = NewCreditRealizationLineageClient(c.config) + c.CreditRealizationLineageSegment = NewCreditRealizationLineageSegmentClient(c.config) c.CurrencyCostBasis = NewCurrencyCostBasisClient(c.config) c.CustomCurrency = NewCustomCurrencyClient(c.config) c.Customer = NewCustomerClient(c.config) @@ -453,6 +461,8 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { ChargeUsageBasedRunPayment: NewChargeUsageBasedRunPaymentClient(cfg), ChargeUsageBasedRuns: NewChargeUsageBasedRunsClient(cfg), ChargesSearchV1: NewChargesSearchV1Client(cfg), + CreditRealizationLineage: NewCreditRealizationLineageClient(cfg), + CreditRealizationLineageSegment: NewCreditRealizationLineageSegmentClient(cfg), CurrencyCostBasis: NewCurrencyCostBasisClient(cfg), CustomCurrency: NewCustomCurrencyClient(cfg), Customer: NewCustomerClient(cfg), @@ -544,6 +554,8 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) ChargeUsageBasedRunPayment: NewChargeUsageBasedRunPaymentClient(cfg), ChargeUsageBasedRuns: NewChargeUsageBasedRunsClient(cfg), ChargesSearchV1: NewChargesSearchV1Client(cfg), + CreditRealizationLineage: NewCreditRealizationLineageClient(cfg), + CreditRealizationLineageSegment: NewCreditRealizationLineageSegmentClient(cfg), CurrencyCostBasis: NewCurrencyCostBasisClient(cfg), CustomCurrency: NewCustomCurrencyClient(cfg), Customer: NewCustomerClient(cfg), @@ -621,16 +633,17 @@ func (c *Client) Use(hooks ...Hook) { c.ChargeFlatFeeCreditAllocations, c.ChargeFlatFeeInvoicedUsage, c.ChargeFlatFeePayment, c.ChargeUsageBased, c.ChargeUsageBasedRunCreditAllocations, c.ChargeUsageBasedRunInvoicedUsage, - c.ChargeUsageBasedRunPayment, c.ChargeUsageBasedRuns, c.CurrencyCostBasis, - c.CustomCurrency, c.Customer, c.CustomerSubjects, c.Entitlement, c.Feature, - c.Grant, c.LLMCostPrice, c.LedgerAccount, c.LedgerCustomerAccount, - c.LedgerEntry, c.LedgerSubAccount, c.LedgerSubAccountRoute, - c.LedgerTransaction, c.LedgerTransactionGroup, c.Meter, c.NotificationChannel, - c.NotificationEvent, c.NotificationEventDeliveryStatus, c.NotificationRule, - c.Plan, c.PlanAddon, c.PlanPhase, c.PlanRateCard, c.Subject, c.Subscription, - c.SubscriptionAddon, c.SubscriptionAddonQuantity, - c.SubscriptionBillingSyncState, c.SubscriptionItem, c.SubscriptionPhase, - c.TaxCode, c.UsageReset, + c.ChargeUsageBasedRunPayment, c.ChargeUsageBasedRuns, + c.CreditRealizationLineage, c.CreditRealizationLineageSegment, + c.CurrencyCostBasis, c.CustomCurrency, c.Customer, c.CustomerSubjects, + c.Entitlement, c.Feature, c.Grant, c.LLMCostPrice, c.LedgerAccount, + c.LedgerCustomerAccount, c.LedgerEntry, c.LedgerSubAccount, + c.LedgerSubAccountRoute, c.LedgerTransaction, c.LedgerTransactionGroup, + c.Meter, c.NotificationChannel, c.NotificationEvent, + c.NotificationEventDeliveryStatus, c.NotificationRule, c.Plan, c.PlanAddon, + c.PlanPhase, c.PlanRateCard, c.Subject, c.Subscription, c.SubscriptionAddon, + c.SubscriptionAddonQuantity, c.SubscriptionBillingSyncState, + c.SubscriptionItem, c.SubscriptionPhase, c.TaxCode, c.UsageReset, } { n.Use(hooks...) } @@ -656,6 +669,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) { c.ChargeFlatFeePayment, c.ChargeUsageBased, c.ChargeUsageBasedRunCreditAllocations, c.ChargeUsageBasedRunInvoicedUsage, c.ChargeUsageBasedRunPayment, c.ChargeUsageBasedRuns, c.ChargesSearchV1, + c.CreditRealizationLineage, c.CreditRealizationLineageSegment, c.CurrencyCostBasis, c.CustomCurrency, c.Customer, c.CustomerSubjects, c.Entitlement, c.Feature, c.Grant, c.LLMCostPrice, c.LedgerAccount, c.LedgerCustomerAccount, c.LedgerEntry, c.LedgerSubAccount, @@ -749,6 +763,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.ChargeUsageBasedRunPayment.mutate(ctx, m) case *ChargeUsageBasedRunsMutation: return c.ChargeUsageBasedRuns.mutate(ctx, m) + case *CreditRealizationLineageMutation: + return c.CreditRealizationLineage.mutate(ctx, m) + case *CreditRealizationLineageSegmentMutation: + return c.CreditRealizationLineageSegment.mutate(ctx, m) case *CurrencyCostBasisMutation: return c.CurrencyCostBasis.mutate(ctx, m) case *CustomCurrencyMutation: @@ -5565,6 +5583,22 @@ func (c *ChargeClient) QueryBillingSplitLineGroups(_m *Charge) *BillingInvoiceSp return query } +// QueryCreditRealizationLineages queries the credit_realization_lineages edge of a Charge. +func (c *ChargeClient) QueryCreditRealizationLineages(_m *Charge) *CreditRealizationLineageQuery { + query := (&CreditRealizationLineageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(charge.Table, charge.FieldID, id), + sqlgraph.To(creditrealizationlineage.Table, creditrealizationlineage.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, charge.CreditRealizationLineagesTable, charge.CreditRealizationLineagesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ChargeClient) Hooks() []Hook { return c.hooks.Charge @@ -7936,6 +7970,320 @@ func (c *ChargesSearchV1Client) Interceptors() []Interceptor { return c.inters.ChargesSearchV1 } +// CreditRealizationLineageClient is a client for the CreditRealizationLineage schema. +type CreditRealizationLineageClient struct { + config +} + +// NewCreditRealizationLineageClient returns a client for the CreditRealizationLineage from the given config. +func NewCreditRealizationLineageClient(c config) *CreditRealizationLineageClient { + return &CreditRealizationLineageClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `creditrealizationlineage.Hooks(f(g(h())))`. +func (c *CreditRealizationLineageClient) Use(hooks ...Hook) { + c.hooks.CreditRealizationLineage = append(c.hooks.CreditRealizationLineage, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `creditrealizationlineage.Intercept(f(g(h())))`. +func (c *CreditRealizationLineageClient) Intercept(interceptors ...Interceptor) { + c.inters.CreditRealizationLineage = append(c.inters.CreditRealizationLineage, interceptors...) +} + +// Create returns a builder for creating a CreditRealizationLineage entity. +func (c *CreditRealizationLineageClient) Create() *CreditRealizationLineageCreate { + mutation := newCreditRealizationLineageMutation(c.config, OpCreate) + return &CreditRealizationLineageCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of CreditRealizationLineage entities. +func (c *CreditRealizationLineageClient) CreateBulk(builders ...*CreditRealizationLineageCreate) *CreditRealizationLineageCreateBulk { + return &CreditRealizationLineageCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *CreditRealizationLineageClient) MapCreateBulk(slice any, setFunc func(*CreditRealizationLineageCreate, int)) *CreditRealizationLineageCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &CreditRealizationLineageCreateBulk{err: fmt.Errorf("calling to CreditRealizationLineageClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*CreditRealizationLineageCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &CreditRealizationLineageCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for CreditRealizationLineage. +func (c *CreditRealizationLineageClient) Update() *CreditRealizationLineageUpdate { + mutation := newCreditRealizationLineageMutation(c.config, OpUpdate) + return &CreditRealizationLineageUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *CreditRealizationLineageClient) UpdateOne(_m *CreditRealizationLineage) *CreditRealizationLineageUpdateOne { + mutation := newCreditRealizationLineageMutation(c.config, OpUpdateOne, withCreditRealizationLineage(_m)) + return &CreditRealizationLineageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *CreditRealizationLineageClient) UpdateOneID(id string) *CreditRealizationLineageUpdateOne { + mutation := newCreditRealizationLineageMutation(c.config, OpUpdateOne, withCreditRealizationLineageID(id)) + return &CreditRealizationLineageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for CreditRealizationLineage. +func (c *CreditRealizationLineageClient) Delete() *CreditRealizationLineageDelete { + mutation := newCreditRealizationLineageMutation(c.config, OpDelete) + return &CreditRealizationLineageDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *CreditRealizationLineageClient) DeleteOne(_m *CreditRealizationLineage) *CreditRealizationLineageDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *CreditRealizationLineageClient) DeleteOneID(id string) *CreditRealizationLineageDeleteOne { + builder := c.Delete().Where(creditrealizationlineage.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &CreditRealizationLineageDeleteOne{builder} +} + +// Query returns a query builder for CreditRealizationLineage. +func (c *CreditRealizationLineageClient) Query() *CreditRealizationLineageQuery { + return &CreditRealizationLineageQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeCreditRealizationLineage}, + inters: c.Interceptors(), + } +} + +// Get returns a CreditRealizationLineage entity by its id. +func (c *CreditRealizationLineageClient) Get(ctx context.Context, id string) (*CreditRealizationLineage, error) { + return c.Query().Where(creditrealizationlineage.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *CreditRealizationLineageClient) GetX(ctx context.Context, id string) *CreditRealizationLineage { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryCharge queries the charge edge of a CreditRealizationLineage. +func (c *CreditRealizationLineageClient) QueryCharge(_m *CreditRealizationLineage) *ChargeQuery { + query := (&ChargeClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineage.Table, creditrealizationlineage.FieldID, id), + sqlgraph.To(charge.Table, charge.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, creditrealizationlineage.ChargeTable, creditrealizationlineage.ChargeColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QuerySegments queries the segments edge of a CreditRealizationLineage. +func (c *CreditRealizationLineageClient) QuerySegments(_m *CreditRealizationLineage) *CreditRealizationLineageSegmentQuery { + query := (&CreditRealizationLineageSegmentClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineage.Table, creditrealizationlineage.FieldID, id), + sqlgraph.To(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, creditrealizationlineage.SegmentsTable, creditrealizationlineage.SegmentsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *CreditRealizationLineageClient) Hooks() []Hook { + return c.hooks.CreditRealizationLineage +} + +// Interceptors returns the client interceptors. +func (c *CreditRealizationLineageClient) Interceptors() []Interceptor { + return c.inters.CreditRealizationLineage +} + +func (c *CreditRealizationLineageClient) mutate(ctx context.Context, m *CreditRealizationLineageMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&CreditRealizationLineageCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&CreditRealizationLineageUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&CreditRealizationLineageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&CreditRealizationLineageDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("db: unknown CreditRealizationLineage mutation op: %q", m.Op()) + } +} + +// CreditRealizationLineageSegmentClient is a client for the CreditRealizationLineageSegment schema. +type CreditRealizationLineageSegmentClient struct { + config +} + +// NewCreditRealizationLineageSegmentClient returns a client for the CreditRealizationLineageSegment from the given config. +func NewCreditRealizationLineageSegmentClient(c config) *CreditRealizationLineageSegmentClient { + return &CreditRealizationLineageSegmentClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `creditrealizationlineagesegment.Hooks(f(g(h())))`. +func (c *CreditRealizationLineageSegmentClient) Use(hooks ...Hook) { + c.hooks.CreditRealizationLineageSegment = append(c.hooks.CreditRealizationLineageSegment, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `creditrealizationlineagesegment.Intercept(f(g(h())))`. +func (c *CreditRealizationLineageSegmentClient) Intercept(interceptors ...Interceptor) { + c.inters.CreditRealizationLineageSegment = append(c.inters.CreditRealizationLineageSegment, interceptors...) +} + +// Create returns a builder for creating a CreditRealizationLineageSegment entity. +func (c *CreditRealizationLineageSegmentClient) Create() *CreditRealizationLineageSegmentCreate { + mutation := newCreditRealizationLineageSegmentMutation(c.config, OpCreate) + return &CreditRealizationLineageSegmentCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of CreditRealizationLineageSegment entities. +func (c *CreditRealizationLineageSegmentClient) CreateBulk(builders ...*CreditRealizationLineageSegmentCreate) *CreditRealizationLineageSegmentCreateBulk { + return &CreditRealizationLineageSegmentCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *CreditRealizationLineageSegmentClient) MapCreateBulk(slice any, setFunc func(*CreditRealizationLineageSegmentCreate, int)) *CreditRealizationLineageSegmentCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &CreditRealizationLineageSegmentCreateBulk{err: fmt.Errorf("calling to CreditRealizationLineageSegmentClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*CreditRealizationLineageSegmentCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &CreditRealizationLineageSegmentCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for CreditRealizationLineageSegment. +func (c *CreditRealizationLineageSegmentClient) Update() *CreditRealizationLineageSegmentUpdate { + mutation := newCreditRealizationLineageSegmentMutation(c.config, OpUpdate) + return &CreditRealizationLineageSegmentUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *CreditRealizationLineageSegmentClient) UpdateOne(_m *CreditRealizationLineageSegment) *CreditRealizationLineageSegmentUpdateOne { + mutation := newCreditRealizationLineageSegmentMutation(c.config, OpUpdateOne, withCreditRealizationLineageSegment(_m)) + return &CreditRealizationLineageSegmentUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *CreditRealizationLineageSegmentClient) UpdateOneID(id string) *CreditRealizationLineageSegmentUpdateOne { + mutation := newCreditRealizationLineageSegmentMutation(c.config, OpUpdateOne, withCreditRealizationLineageSegmentID(id)) + return &CreditRealizationLineageSegmentUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for CreditRealizationLineageSegment. +func (c *CreditRealizationLineageSegmentClient) Delete() *CreditRealizationLineageSegmentDelete { + mutation := newCreditRealizationLineageSegmentMutation(c.config, OpDelete) + return &CreditRealizationLineageSegmentDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *CreditRealizationLineageSegmentClient) DeleteOne(_m *CreditRealizationLineageSegment) *CreditRealizationLineageSegmentDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *CreditRealizationLineageSegmentClient) DeleteOneID(id string) *CreditRealizationLineageSegmentDeleteOne { + builder := c.Delete().Where(creditrealizationlineagesegment.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &CreditRealizationLineageSegmentDeleteOne{builder} +} + +// Query returns a query builder for CreditRealizationLineageSegment. +func (c *CreditRealizationLineageSegmentClient) Query() *CreditRealizationLineageSegmentQuery { + return &CreditRealizationLineageSegmentQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeCreditRealizationLineageSegment}, + inters: c.Interceptors(), + } +} + +// Get returns a CreditRealizationLineageSegment entity by its id. +func (c *CreditRealizationLineageSegmentClient) Get(ctx context.Context, id string) (*CreditRealizationLineageSegment, error) { + return c.Query().Where(creditrealizationlineagesegment.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *CreditRealizationLineageSegmentClient) GetX(ctx context.Context, id string) *CreditRealizationLineageSegment { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryLineage queries the lineage edge of a CreditRealizationLineageSegment. +func (c *CreditRealizationLineageSegmentClient) QueryLineage(_m *CreditRealizationLineageSegment) *CreditRealizationLineageQuery { + query := (&CreditRealizationLineageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.FieldID, id), + sqlgraph.To(creditrealizationlineage.Table, creditrealizationlineage.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, creditrealizationlineagesegment.LineageTable, creditrealizationlineagesegment.LineageColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *CreditRealizationLineageSegmentClient) Hooks() []Hook { + return c.hooks.CreditRealizationLineageSegment +} + +// Interceptors returns the client interceptors. +func (c *CreditRealizationLineageSegmentClient) Interceptors() []Interceptor { + return c.inters.CreditRealizationLineageSegment +} + +func (c *CreditRealizationLineageSegmentClient) mutate(ctx context.Context, m *CreditRealizationLineageSegmentMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&CreditRealizationLineageSegmentCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&CreditRealizationLineageSegmentUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&CreditRealizationLineageSegmentUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&CreditRealizationLineageSegmentDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("db: unknown CreditRealizationLineageSegment mutation op: %q", m.Op()) + } +} + // CurrencyCostBasisClient is a client for the CurrencyCostBasis schema. type CurrencyCostBasisClient struct { config @@ -13846,7 +14194,8 @@ type ( ChargeFlatFeeCreditAllocations, ChargeFlatFeeInvoicedUsage, ChargeFlatFeePayment, ChargeUsageBased, ChargeUsageBasedRunCreditAllocations, ChargeUsageBasedRunInvoicedUsage, ChargeUsageBasedRunPayment, - ChargeUsageBasedRuns, CurrencyCostBasis, CustomCurrency, Customer, + ChargeUsageBasedRuns, CreditRealizationLineage, + CreditRealizationLineageSegment, CurrencyCostBasis, CustomCurrency, Customer, CustomerSubjects, Entitlement, Feature, Grant, LLMCostPrice, LedgerAccount, LedgerCustomerAccount, LedgerEntry, LedgerSubAccount, LedgerSubAccountRoute, LedgerTransaction, LedgerTransactionGroup, Meter, NotificationChannel, @@ -13870,15 +14219,15 @@ type ( ChargeFlatFeeCreditAllocations, ChargeFlatFeeInvoicedUsage, ChargeFlatFeePayment, ChargeUsageBased, ChargeUsageBasedRunCreditAllocations, ChargeUsageBasedRunInvoicedUsage, ChargeUsageBasedRunPayment, - ChargeUsageBasedRuns, ChargesSearchV1, CurrencyCostBasis, CustomCurrency, - Customer, CustomerSubjects, Entitlement, Feature, Grant, LLMCostPrice, - LedgerAccount, LedgerCustomerAccount, LedgerEntry, LedgerSubAccount, - LedgerSubAccountRoute, LedgerTransaction, LedgerTransactionGroup, Meter, - NotificationChannel, NotificationEvent, NotificationEventDeliveryStatus, - NotificationRule, Plan, PlanAddon, PlanPhase, PlanRateCard, Subject, - Subscription, SubscriptionAddon, SubscriptionAddonQuantity, - SubscriptionBillingSyncState, SubscriptionItem, SubscriptionPhase, TaxCode, - UsageReset []ent.Interceptor + ChargeUsageBasedRuns, ChargesSearchV1, CreditRealizationLineage, + CreditRealizationLineageSegment, CurrencyCostBasis, CustomCurrency, Customer, + CustomerSubjects, Entitlement, Feature, Grant, LLMCostPrice, LedgerAccount, + LedgerCustomerAccount, LedgerEntry, LedgerSubAccount, LedgerSubAccountRoute, + LedgerTransaction, LedgerTransactionGroup, Meter, NotificationChannel, + NotificationEvent, NotificationEventDeliveryStatus, NotificationRule, Plan, + PlanAddon, PlanPhase, PlanRateCard, Subject, Subscription, SubscriptionAddon, + SubscriptionAddonQuantity, SubscriptionBillingSyncState, SubscriptionItem, + SubscriptionPhase, TaxCode, UsageReset []ent.Interceptor } ) diff --git a/openmeter/ent/db/creditrealizationlineage.go b/openmeter/ent/db/creditrealizationlineage.go new file mode 100644 index 0000000000..e83ca6b8eb --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage.go @@ -0,0 +1,217 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/charge" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +// CreditRealizationLineage is the model entity for the CreditRealizationLineage schema. +type CreditRealizationLineage struct { + config `json:"-"` + // ID of the ent. + ID string `json:"id,omitempty"` + // Namespace holds the value of the "namespace" field. + Namespace string `json:"namespace,omitempty"` + // ChargeID holds the value of the "charge_id" field. + ChargeID string `json:"charge_id,omitempty"` + // RootRealizationID holds the value of the "root_realization_id" field. + RootRealizationID string `json:"root_realization_id,omitempty"` + // CustomerID holds the value of the "customer_id" field. + CustomerID string `json:"customer_id,omitempty"` + // Currency holds the value of the "currency" field. + Currency currencyx.Code `json:"currency,omitempty"` + // OriginKind holds the value of the "origin_kind" field. + OriginKind creditrealization.LineageOriginKind `json:"origin_kind,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the CreditRealizationLineageQuery when eager-loading is set. + Edges CreditRealizationLineageEdges `json:"edges"` + selectValues sql.SelectValues +} + +// CreditRealizationLineageEdges holds the relations/edges for other nodes in the graph. +type CreditRealizationLineageEdges struct { + // Charge holds the value of the charge edge. + Charge *Charge `json:"charge,omitempty"` + // Segments holds the value of the segments edge. + Segments []*CreditRealizationLineageSegment `json:"segments,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// ChargeOrErr returns the Charge value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e CreditRealizationLineageEdges) ChargeOrErr() (*Charge, error) { + if e.Charge != nil { + return e.Charge, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: charge.Label} + } + return nil, &NotLoadedError{edge: "charge"} +} + +// SegmentsOrErr returns the Segments value or an error if the edge +// was not loaded in eager-loading. +func (e CreditRealizationLineageEdges) SegmentsOrErr() ([]*CreditRealizationLineageSegment, error) { + if e.loadedTypes[1] { + return e.Segments, nil + } + return nil, &NotLoadedError{edge: "segments"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*CreditRealizationLineage) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case creditrealizationlineage.FieldID, creditrealizationlineage.FieldNamespace, creditrealizationlineage.FieldChargeID, creditrealizationlineage.FieldRootRealizationID, creditrealizationlineage.FieldCustomerID, creditrealizationlineage.FieldCurrency, creditrealizationlineage.FieldOriginKind: + values[i] = new(sql.NullString) + case creditrealizationlineage.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the CreditRealizationLineage fields. +func (_m *CreditRealizationLineage) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case creditrealizationlineage.FieldID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value.Valid { + _m.ID = value.String + } + case creditrealizationlineage.FieldNamespace: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field namespace", values[i]) + } else if value.Valid { + _m.Namespace = value.String + } + case creditrealizationlineage.FieldChargeID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field charge_id", values[i]) + } else if value.Valid { + _m.ChargeID = value.String + } + case creditrealizationlineage.FieldRootRealizationID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field root_realization_id", values[i]) + } else if value.Valid { + _m.RootRealizationID = value.String + } + case creditrealizationlineage.FieldCustomerID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field customer_id", values[i]) + } else if value.Valid { + _m.CustomerID = value.String + } + case creditrealizationlineage.FieldCurrency: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field currency", values[i]) + } else if value.Valid { + _m.Currency = currencyx.Code(value.String) + } + case creditrealizationlineage.FieldOriginKind: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field origin_kind", values[i]) + } else if value.Valid { + _m.OriginKind = creditrealization.LineageOriginKind(value.String) + } + case creditrealizationlineage.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the CreditRealizationLineage. +// This includes values selected through modifiers, order, etc. +func (_m *CreditRealizationLineage) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryCharge queries the "charge" edge of the CreditRealizationLineage entity. +func (_m *CreditRealizationLineage) QueryCharge() *ChargeQuery { + return NewCreditRealizationLineageClient(_m.config).QueryCharge(_m) +} + +// QuerySegments queries the "segments" edge of the CreditRealizationLineage entity. +func (_m *CreditRealizationLineage) QuerySegments() *CreditRealizationLineageSegmentQuery { + return NewCreditRealizationLineageClient(_m.config).QuerySegments(_m) +} + +// Update returns a builder for updating this CreditRealizationLineage. +// Note that you need to call CreditRealizationLineage.Unwrap() before calling this method if this CreditRealizationLineage +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *CreditRealizationLineage) Update() *CreditRealizationLineageUpdateOne { + return NewCreditRealizationLineageClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the CreditRealizationLineage entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *CreditRealizationLineage) Unwrap() *CreditRealizationLineage { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("db: CreditRealizationLineage is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *CreditRealizationLineage) String() string { + var builder strings.Builder + builder.WriteString("CreditRealizationLineage(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("namespace=") + builder.WriteString(_m.Namespace) + builder.WriteString(", ") + builder.WriteString("charge_id=") + builder.WriteString(_m.ChargeID) + builder.WriteString(", ") + builder.WriteString("root_realization_id=") + builder.WriteString(_m.RootRealizationID) + builder.WriteString(", ") + builder.WriteString("customer_id=") + builder.WriteString(_m.CustomerID) + builder.WriteString(", ") + builder.WriteString("currency=") + builder.WriteString(fmt.Sprintf("%v", _m.Currency)) + builder.WriteString(", ") + builder.WriteString("origin_kind=") + builder.WriteString(fmt.Sprintf("%v", _m.OriginKind)) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// CreditRealizationLineages is a parsable slice of CreditRealizationLineage. +type CreditRealizationLineages []*CreditRealizationLineage diff --git a/openmeter/ent/db/creditrealizationlineage/creditrealizationlineage.go b/openmeter/ent/db/creditrealizationlineage/creditrealizationlineage.go new file mode 100644 index 0000000000..cf273afc12 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage/creditrealizationlineage.go @@ -0,0 +1,180 @@ +// Code generated by ent, DO NOT EDIT. + +package creditrealizationlineage + +import ( + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" +) + +const ( + // Label holds the string label denoting the creditrealizationlineage type in the database. + Label = "credit_realization_lineage" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldNamespace holds the string denoting the namespace field in the database. + FieldNamespace = "namespace" + // FieldChargeID holds the string denoting the charge_id field in the database. + FieldChargeID = "charge_id" + // FieldRootRealizationID holds the string denoting the root_realization_id field in the database. + FieldRootRealizationID = "root_realization_id" + // FieldCustomerID holds the string denoting the customer_id field in the database. + FieldCustomerID = "customer_id" + // FieldCurrency holds the string denoting the currency field in the database. + FieldCurrency = "currency" + // FieldOriginKind holds the string denoting the origin_kind field in the database. + FieldOriginKind = "origin_kind" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // EdgeCharge holds the string denoting the charge edge name in mutations. + EdgeCharge = "charge" + // EdgeSegments holds the string denoting the segments edge name in mutations. + EdgeSegments = "segments" + // Table holds the table name of the creditrealizationlineage in the database. + Table = "credit_realization_lineages" + // ChargeTable is the table that holds the charge relation/edge. + ChargeTable = "credit_realization_lineages" + // ChargeInverseTable is the table name for the Charge entity. + // It exists in this package in order to avoid circular dependency with the "charge" package. + ChargeInverseTable = "charges" + // ChargeColumn is the table column denoting the charge relation/edge. + ChargeColumn = "charge_id" + // SegmentsTable is the table that holds the segments relation/edge. + SegmentsTable = "credit_realization_lineage_segments" + // SegmentsInverseTable is the table name for the CreditRealizationLineageSegment entity. + // It exists in this package in order to avoid circular dependency with the "creditrealizationlineagesegment" package. + SegmentsInverseTable = "credit_realization_lineage_segments" + // SegmentsColumn is the table column denoting the segments relation/edge. + SegmentsColumn = "lineage_id" +) + +// Columns holds all SQL columns for creditrealizationlineage fields. +var Columns = []string{ + FieldID, + FieldNamespace, + FieldChargeID, + FieldRootRealizationID, + FieldCustomerID, + FieldCurrency, + FieldOriginKind, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // NamespaceValidator is a validator for the "namespace" field. It is called by the builders before save. + NamespaceValidator func(string) error + // ChargeIDValidator is a validator for the "charge_id" field. It is called by the builders before save. + ChargeIDValidator func(string) error + // RootRealizationIDValidator is a validator for the "root_realization_id" field. It is called by the builders before save. + RootRealizationIDValidator func(string) error + // CustomerIDValidator is a validator for the "customer_id" field. It is called by the builders before save. + CustomerIDValidator func(string) error + // CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. + CurrencyValidator func(string) error + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultID holds the default value on creation for the "id" field. + DefaultID func() string +) + +// OriginKindValidator is a validator for the "origin_kind" field enum values. It is called by the builders before save. +func OriginKindValidator(ok creditrealization.LineageOriginKind) error { + switch ok { + case "real_credit", "advance": + return nil + default: + return fmt.Errorf("creditrealizationlineage: invalid enum value for origin_kind field: %q", ok) + } +} + +// OrderOption defines the ordering options for the CreditRealizationLineage queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByNamespace orders the results by the namespace field. +func ByNamespace(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNamespace, opts...).ToFunc() +} + +// ByChargeID orders the results by the charge_id field. +func ByChargeID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChargeID, opts...).ToFunc() +} + +// ByRootRealizationID orders the results by the root_realization_id field. +func ByRootRealizationID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRootRealizationID, opts...).ToFunc() +} + +// ByCustomerID orders the results by the customer_id field. +func ByCustomerID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCustomerID, opts...).ToFunc() +} + +// ByCurrency orders the results by the currency field. +func ByCurrency(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCurrency, opts...).ToFunc() +} + +// ByOriginKind orders the results by the origin_kind field. +func ByOriginKind(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOriginKind, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByChargeField orders the results by charge field. +func ByChargeField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChargeStep(), sql.OrderByField(field, opts...)) + } +} + +// BySegmentsCount orders the results by segments count. +func BySegmentsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSegmentsStep(), opts...) + } +} + +// BySegments orders the results by segments terms. +func BySegments(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSegmentsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newChargeStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChargeInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ChargeTable, ChargeColumn), + ) +} +func newSegmentsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SegmentsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SegmentsTable, SegmentsColumn), + ) +} diff --git a/openmeter/ent/db/creditrealizationlineage/where.go b/openmeter/ent/db/creditrealizationlineage/where.go new file mode 100644 index 0000000000..a2431fc70d --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage/where.go @@ -0,0 +1,574 @@ +// Code generated by ent, DO NOT EDIT. + +package creditrealizationlineage + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +// ID filters vertices based on their ID field. +func ID(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldID, id)) +} + +// IDEqualFold applies the EqualFold predicate on the ID field. +func IDEqualFold(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldID, id)) +} + +// IDContainsFold applies the ContainsFold predicate on the ID field. +func IDContainsFold(id string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldID, id)) +} + +// Namespace applies equality check predicate on the "namespace" field. It's identical to NamespaceEQ. +func Namespace(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldNamespace, v)) +} + +// ChargeID applies equality check predicate on the "charge_id" field. It's identical to ChargeIDEQ. +func ChargeID(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldChargeID, v)) +} + +// RootRealizationID applies equality check predicate on the "root_realization_id" field. It's identical to RootRealizationIDEQ. +func RootRealizationID(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldRootRealizationID, v)) +} + +// CustomerID applies equality check predicate on the "customer_id" field. It's identical to CustomerIDEQ. +func CustomerID(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCustomerID, v)) +} + +// Currency applies equality check predicate on the "currency" field. It's identical to CurrencyEQ. +func Currency(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCurrency, vc)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCreatedAt, v)) +} + +// NamespaceEQ applies the EQ predicate on the "namespace" field. +func NamespaceEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldNamespace, v)) +} + +// NamespaceNEQ applies the NEQ predicate on the "namespace" field. +func NamespaceNEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldNamespace, v)) +} + +// NamespaceIn applies the In predicate on the "namespace" field. +func NamespaceIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldNamespace, vs...)) +} + +// NamespaceNotIn applies the NotIn predicate on the "namespace" field. +func NamespaceNotIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldNamespace, vs...)) +} + +// NamespaceGT applies the GT predicate on the "namespace" field. +func NamespaceGT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldNamespace, v)) +} + +// NamespaceGTE applies the GTE predicate on the "namespace" field. +func NamespaceGTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldNamespace, v)) +} + +// NamespaceLT applies the LT predicate on the "namespace" field. +func NamespaceLT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldNamespace, v)) +} + +// NamespaceLTE applies the LTE predicate on the "namespace" field. +func NamespaceLTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldNamespace, v)) +} + +// NamespaceContains applies the Contains predicate on the "namespace" field. +func NamespaceContains(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContains(FieldNamespace, v)) +} + +// NamespaceHasPrefix applies the HasPrefix predicate on the "namespace" field. +func NamespaceHasPrefix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasPrefix(FieldNamespace, v)) +} + +// NamespaceHasSuffix applies the HasSuffix predicate on the "namespace" field. +func NamespaceHasSuffix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasSuffix(FieldNamespace, v)) +} + +// NamespaceEqualFold applies the EqualFold predicate on the "namespace" field. +func NamespaceEqualFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldNamespace, v)) +} + +// NamespaceContainsFold applies the ContainsFold predicate on the "namespace" field. +func NamespaceContainsFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldNamespace, v)) +} + +// ChargeIDEQ applies the EQ predicate on the "charge_id" field. +func ChargeIDEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldChargeID, v)) +} + +// ChargeIDNEQ applies the NEQ predicate on the "charge_id" field. +func ChargeIDNEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldChargeID, v)) +} + +// ChargeIDIn applies the In predicate on the "charge_id" field. +func ChargeIDIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldChargeID, vs...)) +} + +// ChargeIDNotIn applies the NotIn predicate on the "charge_id" field. +func ChargeIDNotIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldChargeID, vs...)) +} + +// ChargeIDGT applies the GT predicate on the "charge_id" field. +func ChargeIDGT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldChargeID, v)) +} + +// ChargeIDGTE applies the GTE predicate on the "charge_id" field. +func ChargeIDGTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldChargeID, v)) +} + +// ChargeIDLT applies the LT predicate on the "charge_id" field. +func ChargeIDLT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldChargeID, v)) +} + +// ChargeIDLTE applies the LTE predicate on the "charge_id" field. +func ChargeIDLTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldChargeID, v)) +} + +// ChargeIDContains applies the Contains predicate on the "charge_id" field. +func ChargeIDContains(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContains(FieldChargeID, v)) +} + +// ChargeIDHasPrefix applies the HasPrefix predicate on the "charge_id" field. +func ChargeIDHasPrefix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasPrefix(FieldChargeID, v)) +} + +// ChargeIDHasSuffix applies the HasSuffix predicate on the "charge_id" field. +func ChargeIDHasSuffix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasSuffix(FieldChargeID, v)) +} + +// ChargeIDEqualFold applies the EqualFold predicate on the "charge_id" field. +func ChargeIDEqualFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldChargeID, v)) +} + +// ChargeIDContainsFold applies the ContainsFold predicate on the "charge_id" field. +func ChargeIDContainsFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldChargeID, v)) +} + +// RootRealizationIDEQ applies the EQ predicate on the "root_realization_id" field. +func RootRealizationIDEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldRootRealizationID, v)) +} + +// RootRealizationIDNEQ applies the NEQ predicate on the "root_realization_id" field. +func RootRealizationIDNEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldRootRealizationID, v)) +} + +// RootRealizationIDIn applies the In predicate on the "root_realization_id" field. +func RootRealizationIDIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldRootRealizationID, vs...)) +} + +// RootRealizationIDNotIn applies the NotIn predicate on the "root_realization_id" field. +func RootRealizationIDNotIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldRootRealizationID, vs...)) +} + +// RootRealizationIDGT applies the GT predicate on the "root_realization_id" field. +func RootRealizationIDGT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldRootRealizationID, v)) +} + +// RootRealizationIDGTE applies the GTE predicate on the "root_realization_id" field. +func RootRealizationIDGTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldRootRealizationID, v)) +} + +// RootRealizationIDLT applies the LT predicate on the "root_realization_id" field. +func RootRealizationIDLT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldRootRealizationID, v)) +} + +// RootRealizationIDLTE applies the LTE predicate on the "root_realization_id" field. +func RootRealizationIDLTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldRootRealizationID, v)) +} + +// RootRealizationIDContains applies the Contains predicate on the "root_realization_id" field. +func RootRealizationIDContains(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContains(FieldRootRealizationID, v)) +} + +// RootRealizationIDHasPrefix applies the HasPrefix predicate on the "root_realization_id" field. +func RootRealizationIDHasPrefix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasPrefix(FieldRootRealizationID, v)) +} + +// RootRealizationIDHasSuffix applies the HasSuffix predicate on the "root_realization_id" field. +func RootRealizationIDHasSuffix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasSuffix(FieldRootRealizationID, v)) +} + +// RootRealizationIDEqualFold applies the EqualFold predicate on the "root_realization_id" field. +func RootRealizationIDEqualFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldRootRealizationID, v)) +} + +// RootRealizationIDContainsFold applies the ContainsFold predicate on the "root_realization_id" field. +func RootRealizationIDContainsFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldRootRealizationID, v)) +} + +// CustomerIDEQ applies the EQ predicate on the "customer_id" field. +func CustomerIDEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCustomerID, v)) +} + +// CustomerIDNEQ applies the NEQ predicate on the "customer_id" field. +func CustomerIDNEQ(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldCustomerID, v)) +} + +// CustomerIDIn applies the In predicate on the "customer_id" field. +func CustomerIDIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldCustomerID, vs...)) +} + +// CustomerIDNotIn applies the NotIn predicate on the "customer_id" field. +func CustomerIDNotIn(vs ...string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldCustomerID, vs...)) +} + +// CustomerIDGT applies the GT predicate on the "customer_id" field. +func CustomerIDGT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldCustomerID, v)) +} + +// CustomerIDGTE applies the GTE predicate on the "customer_id" field. +func CustomerIDGTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldCustomerID, v)) +} + +// CustomerIDLT applies the LT predicate on the "customer_id" field. +func CustomerIDLT(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldCustomerID, v)) +} + +// CustomerIDLTE applies the LTE predicate on the "customer_id" field. +func CustomerIDLTE(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldCustomerID, v)) +} + +// CustomerIDContains applies the Contains predicate on the "customer_id" field. +func CustomerIDContains(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContains(FieldCustomerID, v)) +} + +// CustomerIDHasPrefix applies the HasPrefix predicate on the "customer_id" field. +func CustomerIDHasPrefix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasPrefix(FieldCustomerID, v)) +} + +// CustomerIDHasSuffix applies the HasSuffix predicate on the "customer_id" field. +func CustomerIDHasSuffix(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldHasSuffix(FieldCustomerID, v)) +} + +// CustomerIDEqualFold applies the EqualFold predicate on the "customer_id" field. +func CustomerIDEqualFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldCustomerID, v)) +} + +// CustomerIDContainsFold applies the ContainsFold predicate on the "customer_id" field. +func CustomerIDContainsFold(v string) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldCustomerID, v)) +} + +// CurrencyEQ applies the EQ predicate on the "currency" field. +func CurrencyEQ(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCurrency, vc)) +} + +// CurrencyNEQ applies the NEQ predicate on the "currency" field. +func CurrencyNEQ(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldCurrency, vc)) +} + +// CurrencyIn applies the In predicate on the "currency" field. +func CurrencyIn(vs ...currencyx.Code) predicate.CreditRealizationLineage { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.CreditRealizationLineage(sql.FieldIn(FieldCurrency, v...)) +} + +// CurrencyNotIn applies the NotIn predicate on the "currency" field. +func CurrencyNotIn(vs ...currencyx.Code) predicate.CreditRealizationLineage { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldCurrency, v...)) +} + +// CurrencyGT applies the GT predicate on the "currency" field. +func CurrencyGT(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldGT(FieldCurrency, vc)) +} + +// CurrencyGTE applies the GTE predicate on the "currency" field. +func CurrencyGTE(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldCurrency, vc)) +} + +// CurrencyLT applies the LT predicate on the "currency" field. +func CurrencyLT(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldLT(FieldCurrency, vc)) +} + +// CurrencyLTE applies the LTE predicate on the "currency" field. +func CurrencyLTE(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldCurrency, vc)) +} + +// CurrencyContains applies the Contains predicate on the "currency" field. +func CurrencyContains(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldContains(FieldCurrency, vc)) +} + +// CurrencyHasPrefix applies the HasPrefix predicate on the "currency" field. +func CurrencyHasPrefix(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldHasPrefix(FieldCurrency, vc)) +} + +// CurrencyHasSuffix applies the HasSuffix predicate on the "currency" field. +func CurrencyHasSuffix(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldHasSuffix(FieldCurrency, vc)) +} + +// CurrencyEqualFold applies the EqualFold predicate on the "currency" field. +func CurrencyEqualFold(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldEqualFold(FieldCurrency, vc)) +} + +// CurrencyContainsFold applies the ContainsFold predicate on the "currency" field. +func CurrencyContainsFold(v currencyx.Code) predicate.CreditRealizationLineage { + vc := string(v) + return predicate.CreditRealizationLineage(sql.FieldContainsFold(FieldCurrency, vc)) +} + +// OriginKindEQ applies the EQ predicate on the "origin_kind" field. +func OriginKindEQ(v creditrealization.LineageOriginKind) predicate.CreditRealizationLineage { + vc := v + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldOriginKind, vc)) +} + +// OriginKindNEQ applies the NEQ predicate on the "origin_kind" field. +func OriginKindNEQ(v creditrealization.LineageOriginKind) predicate.CreditRealizationLineage { + vc := v + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldOriginKind, vc)) +} + +// OriginKindIn applies the In predicate on the "origin_kind" field. +func OriginKindIn(vs ...creditrealization.LineageOriginKind) predicate.CreditRealizationLineage { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.CreditRealizationLineage(sql.FieldIn(FieldOriginKind, v...)) +} + +// OriginKindNotIn applies the NotIn predicate on the "origin_kind" field. +func OriginKindNotIn(vs ...creditrealization.LineageOriginKind) predicate.CreditRealizationLineage { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldOriginKind, v...)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.FieldLTE(FieldCreatedAt, v)) +} + +// HasCharge applies the HasEdge predicate on the "charge" edge. +func HasCharge() predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ChargeTable, ChargeColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChargeWith applies the HasEdge predicate on the "charge" edge with a given conditions (other predicates). +func HasChargeWith(preds ...predicate.Charge) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(func(s *sql.Selector) { + step := newChargeStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasSegments applies the HasEdge predicate on the "segments" edge. +func HasSegments() predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SegmentsTable, SegmentsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasSegmentsWith applies the HasEdge predicate on the "segments" edge with a given conditions (other predicates). +func HasSegmentsWith(preds ...predicate.CreditRealizationLineageSegment) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(func(s *sql.Selector) { + step := newSegmentsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.CreditRealizationLineage) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.CreditRealizationLineage) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.CreditRealizationLineage) predicate.CreditRealizationLineage { + return predicate.CreditRealizationLineage(sql.NotPredicates(p)) +} diff --git a/openmeter/ent/db/creditrealizationlineage_create.go b/openmeter/ent/db/creditrealizationlineage_create.go new file mode 100644 index 0000000000..4a2d0758b9 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage_create.go @@ -0,0 +1,686 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/charge" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +// CreditRealizationLineageCreate is the builder for creating a CreditRealizationLineage entity. +type CreditRealizationLineageCreate struct { + config + mutation *CreditRealizationLineageMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetNamespace sets the "namespace" field. +func (_c *CreditRealizationLineageCreate) SetNamespace(v string) *CreditRealizationLineageCreate { + _c.mutation.SetNamespace(v) + return _c +} + +// SetChargeID sets the "charge_id" field. +func (_c *CreditRealizationLineageCreate) SetChargeID(v string) *CreditRealizationLineageCreate { + _c.mutation.SetChargeID(v) + return _c +} + +// SetRootRealizationID sets the "root_realization_id" field. +func (_c *CreditRealizationLineageCreate) SetRootRealizationID(v string) *CreditRealizationLineageCreate { + _c.mutation.SetRootRealizationID(v) + return _c +} + +// SetCustomerID sets the "customer_id" field. +func (_c *CreditRealizationLineageCreate) SetCustomerID(v string) *CreditRealizationLineageCreate { + _c.mutation.SetCustomerID(v) + return _c +} + +// SetCurrency sets the "currency" field. +func (_c *CreditRealizationLineageCreate) SetCurrency(v currencyx.Code) *CreditRealizationLineageCreate { + _c.mutation.SetCurrency(v) + return _c +} + +// SetOriginKind sets the "origin_kind" field. +func (_c *CreditRealizationLineageCreate) SetOriginKind(v creditrealization.LineageOriginKind) *CreditRealizationLineageCreate { + _c.mutation.SetOriginKind(v) + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *CreditRealizationLineageCreate) SetCreatedAt(v time.Time) *CreditRealizationLineageCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *CreditRealizationLineageCreate) SetNillableCreatedAt(v *time.Time) *CreditRealizationLineageCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetID sets the "id" field. +func (_c *CreditRealizationLineageCreate) SetID(v string) *CreditRealizationLineageCreate { + _c.mutation.SetID(v) + return _c +} + +// SetNillableID sets the "id" field if the given value is not nil. +func (_c *CreditRealizationLineageCreate) SetNillableID(v *string) *CreditRealizationLineageCreate { + if v != nil { + _c.SetID(*v) + } + return _c +} + +// SetCharge sets the "charge" edge to the Charge entity. +func (_c *CreditRealizationLineageCreate) SetCharge(v *Charge) *CreditRealizationLineageCreate { + return _c.SetChargeID(v.ID) +} + +// AddSegmentIDs adds the "segments" edge to the CreditRealizationLineageSegment entity by IDs. +func (_c *CreditRealizationLineageCreate) AddSegmentIDs(ids ...string) *CreditRealizationLineageCreate { + _c.mutation.AddSegmentIDs(ids...) + return _c +} + +// AddSegments adds the "segments" edges to the CreditRealizationLineageSegment entity. +func (_c *CreditRealizationLineageCreate) AddSegments(v ...*CreditRealizationLineageSegment) *CreditRealizationLineageCreate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddSegmentIDs(ids...) +} + +// Mutation returns the CreditRealizationLineageMutation object of the builder. +func (_c *CreditRealizationLineageCreate) Mutation() *CreditRealizationLineageMutation { + return _c.mutation +} + +// Save creates the CreditRealizationLineage in the database. +func (_c *CreditRealizationLineageCreate) Save(ctx context.Context) (*CreditRealizationLineage, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *CreditRealizationLineageCreate) SaveX(ctx context.Context) *CreditRealizationLineage { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *CreditRealizationLineageCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *CreditRealizationLineageCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *CreditRealizationLineageCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := creditrealizationlineage.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.ID(); !ok { + v := creditrealizationlineage.DefaultID() + _c.mutation.SetID(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *CreditRealizationLineageCreate) check() error { + if _, ok := _c.mutation.Namespace(); !ok { + return &ValidationError{Name: "namespace", err: errors.New(`db: missing required field "CreditRealizationLineage.namespace"`)} + } + if v, ok := _c.mutation.Namespace(); ok { + if err := creditrealizationlineage.NamespaceValidator(v); err != nil { + return &ValidationError{Name: "namespace", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.namespace": %w`, err)} + } + } + if _, ok := _c.mutation.ChargeID(); !ok { + return &ValidationError{Name: "charge_id", err: errors.New(`db: missing required field "CreditRealizationLineage.charge_id"`)} + } + if v, ok := _c.mutation.ChargeID(); ok { + if err := creditrealizationlineage.ChargeIDValidator(v); err != nil { + return &ValidationError{Name: "charge_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.charge_id": %w`, err)} + } + } + if _, ok := _c.mutation.RootRealizationID(); !ok { + return &ValidationError{Name: "root_realization_id", err: errors.New(`db: missing required field "CreditRealizationLineage.root_realization_id"`)} + } + if v, ok := _c.mutation.RootRealizationID(); ok { + if err := creditrealizationlineage.RootRealizationIDValidator(v); err != nil { + return &ValidationError{Name: "root_realization_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.root_realization_id": %w`, err)} + } + } + if _, ok := _c.mutation.CustomerID(); !ok { + return &ValidationError{Name: "customer_id", err: errors.New(`db: missing required field "CreditRealizationLineage.customer_id"`)} + } + if v, ok := _c.mutation.CustomerID(); ok { + if err := creditrealizationlineage.CustomerIDValidator(v); err != nil { + return &ValidationError{Name: "customer_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.customer_id": %w`, err)} + } + } + if _, ok := _c.mutation.Currency(); !ok { + return &ValidationError{Name: "currency", err: errors.New(`db: missing required field "CreditRealizationLineage.currency"`)} + } + if v, ok := _c.mutation.Currency(); ok { + if err := creditrealizationlineage.CurrencyValidator(string(v)); err != nil { + return &ValidationError{Name: "currency", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.currency": %w`, err)} + } + } + if _, ok := _c.mutation.OriginKind(); !ok { + return &ValidationError{Name: "origin_kind", err: errors.New(`db: missing required field "CreditRealizationLineage.origin_kind"`)} + } + if v, ok := _c.mutation.OriginKind(); ok { + if err := creditrealizationlineage.OriginKindValidator(v); err != nil { + return &ValidationError{Name: "origin_kind", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineage.origin_kind": %w`, err)} + } + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "CreditRealizationLineage.created_at"`)} + } + if len(_c.mutation.ChargeIDs()) == 0 { + return &ValidationError{Name: "charge", err: errors.New(`db: missing required edge "CreditRealizationLineage.charge"`)} + } + return nil +} + +func (_c *CreditRealizationLineageCreate) sqlSave(ctx context.Context) (*CreditRealizationLineage, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(string); ok { + _node.ID = id + } else { + return nil, fmt.Errorf("unexpected CreditRealizationLineage.ID type: %T", _spec.ID.Value) + } + } + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *CreditRealizationLineageCreate) createSpec() (*CreditRealizationLineage, *sqlgraph.CreateSpec) { + var ( + _node = &CreditRealizationLineage{config: _c.config} + _spec = sqlgraph.NewCreateSpec(creditrealizationlineage.Table, sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString)) + ) + _spec.OnConflict = _c.conflict + if id, ok := _c.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if value, ok := _c.mutation.Namespace(); ok { + _spec.SetField(creditrealizationlineage.FieldNamespace, field.TypeString, value) + _node.Namespace = value + } + if value, ok := _c.mutation.RootRealizationID(); ok { + _spec.SetField(creditrealizationlineage.FieldRootRealizationID, field.TypeString, value) + _node.RootRealizationID = value + } + if value, ok := _c.mutation.CustomerID(); ok { + _spec.SetField(creditrealizationlineage.FieldCustomerID, field.TypeString, value) + _node.CustomerID = value + } + if value, ok := _c.mutation.Currency(); ok { + _spec.SetField(creditrealizationlineage.FieldCurrency, field.TypeString, value) + _node.Currency = value + } + if value, ok := _c.mutation.OriginKind(); ok { + _spec.SetField(creditrealizationlineage.FieldOriginKind, field.TypeEnum, value) + _node.OriginKind = value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(creditrealizationlineage.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if nodes := _c.mutation.ChargeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: creditrealizationlineage.ChargeTable, + Columns: []string{creditrealizationlineage.ChargeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(charge.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.ChargeID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.SegmentsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.CreditRealizationLineage.Create(). +// SetNamespace(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.CreditRealizationLineageUpsert) { +// SetNamespace(v+v). +// }). +// Exec(ctx) +func (_c *CreditRealizationLineageCreate) OnConflict(opts ...sql.ConflictOption) *CreditRealizationLineageUpsertOne { + _c.conflict = opts + return &CreditRealizationLineageUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *CreditRealizationLineageCreate) OnConflictColumns(columns ...string) *CreditRealizationLineageUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &CreditRealizationLineageUpsertOne{ + create: _c, + } +} + +type ( + // CreditRealizationLineageUpsertOne is the builder for "upsert"-ing + // one CreditRealizationLineage node. + CreditRealizationLineageUpsertOne struct { + create *CreditRealizationLineageCreate + } + + // CreditRealizationLineageUpsert is the "OnConflict" setter. + CreditRealizationLineageUpsert struct { + *sql.UpdateSet + } +) + +// UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(creditrealizationlineage.FieldID) +// }), +// ). +// Exec(ctx) +func (u *CreditRealizationLineageUpsertOne) UpdateNewValues() *CreditRealizationLineageUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.ID(); exists { + s.SetIgnore(creditrealizationlineage.FieldID) + } + if _, exists := u.create.mutation.Namespace(); exists { + s.SetIgnore(creditrealizationlineage.FieldNamespace) + } + if _, exists := u.create.mutation.ChargeID(); exists { + s.SetIgnore(creditrealizationlineage.FieldChargeID) + } + if _, exists := u.create.mutation.RootRealizationID(); exists { + s.SetIgnore(creditrealizationlineage.FieldRootRealizationID) + } + if _, exists := u.create.mutation.CustomerID(); exists { + s.SetIgnore(creditrealizationlineage.FieldCustomerID) + } + if _, exists := u.create.mutation.Currency(); exists { + s.SetIgnore(creditrealizationlineage.FieldCurrency) + } + if _, exists := u.create.mutation.OriginKind(); exists { + s.SetIgnore(creditrealizationlineage.FieldOriginKind) + } + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(creditrealizationlineage.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *CreditRealizationLineageUpsertOne) Ignore() *CreditRealizationLineageUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *CreditRealizationLineageUpsertOne) DoNothing() *CreditRealizationLineageUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the CreditRealizationLineageCreate.OnConflict +// documentation for more info. +func (u *CreditRealizationLineageUpsertOne) Update(set func(*CreditRealizationLineageUpsert)) *CreditRealizationLineageUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&CreditRealizationLineageUpsert{UpdateSet: update}) + })) + return u +} + +// Exec executes the query. +func (u *CreditRealizationLineageUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("db: missing options for CreditRealizationLineageCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *CreditRealizationLineageUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *CreditRealizationLineageUpsertOne) ID(ctx context.Context) (id string, err error) { + if u.create.driver.Dialect() == dialect.MySQL { + // In case of "ON CONFLICT", there is no way to get back non-numeric ID + // fields from the database since MySQL does not support the RETURNING clause. + return id, errors.New("db: CreditRealizationLineageUpsertOne.ID is not supported by MySQL driver. Use CreditRealizationLineageUpsertOne.Exec instead") + } + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *CreditRealizationLineageUpsertOne) IDX(ctx context.Context) string { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// CreditRealizationLineageCreateBulk is the builder for creating many CreditRealizationLineage entities in bulk. +type CreditRealizationLineageCreateBulk struct { + config + err error + builders []*CreditRealizationLineageCreate + conflict []sql.ConflictOption +} + +// Save creates the CreditRealizationLineage entities in the database. +func (_c *CreditRealizationLineageCreateBulk) Save(ctx context.Context) ([]*CreditRealizationLineage, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*CreditRealizationLineage, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*CreditRealizationLineageMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *CreditRealizationLineageCreateBulk) SaveX(ctx context.Context) []*CreditRealizationLineage { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *CreditRealizationLineageCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *CreditRealizationLineageCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.CreditRealizationLineage.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.CreditRealizationLineageUpsert) { +// SetNamespace(v+v). +// }). +// Exec(ctx) +func (_c *CreditRealizationLineageCreateBulk) OnConflict(opts ...sql.ConflictOption) *CreditRealizationLineageUpsertBulk { + _c.conflict = opts + return &CreditRealizationLineageUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *CreditRealizationLineageCreateBulk) OnConflictColumns(columns ...string) *CreditRealizationLineageUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &CreditRealizationLineageUpsertBulk{ + create: _c, + } +} + +// CreditRealizationLineageUpsertBulk is the builder for "upsert"-ing +// a bulk of CreditRealizationLineage nodes. +type CreditRealizationLineageUpsertBulk struct { + create *CreditRealizationLineageCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(creditrealizationlineage.FieldID) +// }), +// ). +// Exec(ctx) +func (u *CreditRealizationLineageUpsertBulk) UpdateNewValues() *CreditRealizationLineageUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.ID(); exists { + s.SetIgnore(creditrealizationlineage.FieldID) + } + if _, exists := b.mutation.Namespace(); exists { + s.SetIgnore(creditrealizationlineage.FieldNamespace) + } + if _, exists := b.mutation.ChargeID(); exists { + s.SetIgnore(creditrealizationlineage.FieldChargeID) + } + if _, exists := b.mutation.RootRealizationID(); exists { + s.SetIgnore(creditrealizationlineage.FieldRootRealizationID) + } + if _, exists := b.mutation.CustomerID(); exists { + s.SetIgnore(creditrealizationlineage.FieldCustomerID) + } + if _, exists := b.mutation.Currency(); exists { + s.SetIgnore(creditrealizationlineage.FieldCurrency) + } + if _, exists := b.mutation.OriginKind(); exists { + s.SetIgnore(creditrealizationlineage.FieldOriginKind) + } + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(creditrealizationlineage.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineage.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *CreditRealizationLineageUpsertBulk) Ignore() *CreditRealizationLineageUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *CreditRealizationLineageUpsertBulk) DoNothing() *CreditRealizationLineageUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the CreditRealizationLineageCreateBulk.OnConflict +// documentation for more info. +func (u *CreditRealizationLineageUpsertBulk) Update(set func(*CreditRealizationLineageUpsert)) *CreditRealizationLineageUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&CreditRealizationLineageUpsert{UpdateSet: update}) + })) + return u +} + +// Exec executes the query. +func (u *CreditRealizationLineageUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("db: OnConflict was set for builder %d. Set it on the CreditRealizationLineageCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("db: missing options for CreditRealizationLineageCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *CreditRealizationLineageUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/openmeter/ent/db/creditrealizationlineage_delete.go b/openmeter/ent/db/creditrealizationlineage_delete.go new file mode 100644 index 0000000000..75de325fee --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageDelete is the builder for deleting a CreditRealizationLineage entity. +type CreditRealizationLineageDelete struct { + config + hooks []Hook + mutation *CreditRealizationLineageMutation +} + +// Where appends a list predicates to the CreditRealizationLineageDelete builder. +func (_d *CreditRealizationLineageDelete) Where(ps ...predicate.CreditRealizationLineage) *CreditRealizationLineageDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *CreditRealizationLineageDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *CreditRealizationLineageDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *CreditRealizationLineageDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(creditrealizationlineage.Table, sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// CreditRealizationLineageDeleteOne is the builder for deleting a single CreditRealizationLineage entity. +type CreditRealizationLineageDeleteOne struct { + _d *CreditRealizationLineageDelete +} + +// Where appends a list predicates to the CreditRealizationLineageDelete builder. +func (_d *CreditRealizationLineageDeleteOne) Where(ps ...predicate.CreditRealizationLineage) *CreditRealizationLineageDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *CreditRealizationLineageDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{creditrealizationlineage.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *CreditRealizationLineageDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/openmeter/ent/db/creditrealizationlineage_query.go b/openmeter/ent/db/creditrealizationlineage_query.go new file mode 100644 index 0000000000..42eb7270e9 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage_query.go @@ -0,0 +1,720 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/ent/db/charge" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageQuery is the builder for querying CreditRealizationLineage entities. +type CreditRealizationLineageQuery struct { + config + ctx *QueryContext + order []creditrealizationlineage.OrderOption + inters []Interceptor + predicates []predicate.CreditRealizationLineage + withCharge *ChargeQuery + withSegments *CreditRealizationLineageSegmentQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the CreditRealizationLineageQuery builder. +func (_q *CreditRealizationLineageQuery) Where(ps ...predicate.CreditRealizationLineage) *CreditRealizationLineageQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *CreditRealizationLineageQuery) Limit(limit int) *CreditRealizationLineageQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *CreditRealizationLineageQuery) Offset(offset int) *CreditRealizationLineageQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *CreditRealizationLineageQuery) Unique(unique bool) *CreditRealizationLineageQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *CreditRealizationLineageQuery) Order(o ...creditrealizationlineage.OrderOption) *CreditRealizationLineageQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryCharge chains the current query on the "charge" edge. +func (_q *CreditRealizationLineageQuery) QueryCharge() *ChargeQuery { + query := (&ChargeClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineage.Table, creditrealizationlineage.FieldID, selector), + sqlgraph.To(charge.Table, charge.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, creditrealizationlineage.ChargeTable, creditrealizationlineage.ChargeColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QuerySegments chains the current query on the "segments" edge. +func (_q *CreditRealizationLineageQuery) QuerySegments() *CreditRealizationLineageSegmentQuery { + query := (&CreditRealizationLineageSegmentClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineage.Table, creditrealizationlineage.FieldID, selector), + sqlgraph.To(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, creditrealizationlineage.SegmentsTable, creditrealizationlineage.SegmentsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first CreditRealizationLineage entity from the query. +// Returns a *NotFoundError when no CreditRealizationLineage was found. +func (_q *CreditRealizationLineageQuery) First(ctx context.Context) (*CreditRealizationLineage, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{creditrealizationlineage.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) FirstX(ctx context.Context) *CreditRealizationLineage { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first CreditRealizationLineage ID from the query. +// Returns a *NotFoundError when no CreditRealizationLineage ID was found. +func (_q *CreditRealizationLineageQuery) FirstID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{creditrealizationlineage.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) FirstIDX(ctx context.Context) string { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single CreditRealizationLineage entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one CreditRealizationLineage entity is found. +// Returns a *NotFoundError when no CreditRealizationLineage entities are found. +func (_q *CreditRealizationLineageQuery) Only(ctx context.Context) (*CreditRealizationLineage, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{creditrealizationlineage.Label} + default: + return nil, &NotSingularError{creditrealizationlineage.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) OnlyX(ctx context.Context) *CreditRealizationLineage { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only CreditRealizationLineage ID in the query. +// Returns a *NotSingularError when more than one CreditRealizationLineage ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *CreditRealizationLineageQuery) OnlyID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{creditrealizationlineage.Label} + default: + err = &NotSingularError{creditrealizationlineage.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) OnlyIDX(ctx context.Context) string { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of CreditRealizationLineages. +func (_q *CreditRealizationLineageQuery) All(ctx context.Context) ([]*CreditRealizationLineage, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*CreditRealizationLineage, *CreditRealizationLineageQuery]() + return withInterceptors[[]*CreditRealizationLineage](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) AllX(ctx context.Context) []*CreditRealizationLineage { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of CreditRealizationLineage IDs. +func (_q *CreditRealizationLineageQuery) IDs(ctx context.Context) (ids []string, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(creditrealizationlineage.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) IDsX(ctx context.Context) []string { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *CreditRealizationLineageQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*CreditRealizationLineageQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *CreditRealizationLineageQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("db: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *CreditRealizationLineageQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the CreditRealizationLineageQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *CreditRealizationLineageQuery) Clone() *CreditRealizationLineageQuery { + if _q == nil { + return nil + } + return &CreditRealizationLineageQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]creditrealizationlineage.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.CreditRealizationLineage{}, _q.predicates...), + withCharge: _q.withCharge.Clone(), + withSegments: _q.withSegments.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithCharge tells the query-builder to eager-load the nodes that are connected to +// the "charge" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *CreditRealizationLineageQuery) WithCharge(opts ...func(*ChargeQuery)) *CreditRealizationLineageQuery { + query := (&ChargeClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withCharge = query + return _q +} + +// WithSegments tells the query-builder to eager-load the nodes that are connected to +// the "segments" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *CreditRealizationLineageQuery) WithSegments(opts ...func(*CreditRealizationLineageSegmentQuery)) *CreditRealizationLineageQuery { + query := (&CreditRealizationLineageSegmentClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withSegments = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Namespace string `json:"namespace,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.CreditRealizationLineage.Query(). +// GroupBy(creditrealizationlineage.FieldNamespace). +// Aggregate(db.Count()). +// Scan(ctx, &v) +func (_q *CreditRealizationLineageQuery) GroupBy(field string, fields ...string) *CreditRealizationLineageGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &CreditRealizationLineageGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = creditrealizationlineage.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Namespace string `json:"namespace,omitempty"` +// } +// +// client.CreditRealizationLineage.Query(). +// Select(creditrealizationlineage.FieldNamespace). +// Scan(ctx, &v) +func (_q *CreditRealizationLineageQuery) Select(fields ...string) *CreditRealizationLineageSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &CreditRealizationLineageSelect{CreditRealizationLineageQuery: _q} + sbuild.label = creditrealizationlineage.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a CreditRealizationLineageSelect configured with the given aggregations. +func (_q *CreditRealizationLineageQuery) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *CreditRealizationLineageQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !creditrealizationlineage.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *CreditRealizationLineageQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*CreditRealizationLineage, error) { + var ( + nodes = []*CreditRealizationLineage{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withCharge != nil, + _q.withSegments != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*CreditRealizationLineage).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &CreditRealizationLineage{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withCharge; query != nil { + if err := _q.loadCharge(ctx, query, nodes, nil, + func(n *CreditRealizationLineage, e *Charge) { n.Edges.Charge = e }); err != nil { + return nil, err + } + } + if query := _q.withSegments; query != nil { + if err := _q.loadSegments(ctx, query, nodes, + func(n *CreditRealizationLineage) { n.Edges.Segments = []*CreditRealizationLineageSegment{} }, + func(n *CreditRealizationLineage, e *CreditRealizationLineageSegment) { + n.Edges.Segments = append(n.Edges.Segments, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *CreditRealizationLineageQuery) loadCharge(ctx context.Context, query *ChargeQuery, nodes []*CreditRealizationLineage, init func(*CreditRealizationLineage), assign func(*CreditRealizationLineage, *Charge)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*CreditRealizationLineage) + for i := range nodes { + fk := nodes[i].ChargeID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(charge.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "charge_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *CreditRealizationLineageQuery) loadSegments(ctx context.Context, query *CreditRealizationLineageSegmentQuery, nodes []*CreditRealizationLineage, init func(*CreditRealizationLineage), assign func(*CreditRealizationLineage, *CreditRealizationLineageSegment)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[string]*CreditRealizationLineage) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(creditrealizationlineagesegment.FieldLineageID) + } + query.Where(predicate.CreditRealizationLineageSegment(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(creditrealizationlineage.SegmentsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.LineageID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "lineage_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *CreditRealizationLineageQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *CreditRealizationLineageQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(creditrealizationlineage.Table, creditrealizationlineage.Columns, sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, creditrealizationlineage.FieldID) + for i := range fields { + if fields[i] != creditrealizationlineage.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withCharge != nil { + _spec.Node.AddColumnOnce(creditrealizationlineage.FieldChargeID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *CreditRealizationLineageQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(creditrealizationlineage.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = creditrealizationlineage.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *CreditRealizationLineageQuery) ForUpdate(opts ...sql.LockOption) *CreditRealizationLineageQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *CreditRealizationLineageQuery) ForShare(opts ...sql.LockOption) *CreditRealizationLineageQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// CreditRealizationLineageGroupBy is the group-by builder for CreditRealizationLineage entities. +type CreditRealizationLineageGroupBy struct { + selector + build *CreditRealizationLineageQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *CreditRealizationLineageGroupBy) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *CreditRealizationLineageGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CreditRealizationLineageQuery, *CreditRealizationLineageGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *CreditRealizationLineageGroupBy) sqlScan(ctx context.Context, root *CreditRealizationLineageQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// CreditRealizationLineageSelect is the builder for selecting fields of CreditRealizationLineage entities. +type CreditRealizationLineageSelect struct { + *CreditRealizationLineageQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *CreditRealizationLineageSelect) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *CreditRealizationLineageSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CreditRealizationLineageQuery, *CreditRealizationLineageSelect](ctx, _s.CreditRealizationLineageQuery, _s, _s.inters, v) +} + +func (_s *CreditRealizationLineageSelect) sqlScan(ctx context.Context, root *CreditRealizationLineageQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/openmeter/ent/db/creditrealizationlineage_update.go b/openmeter/ent/db/creditrealizationlineage_update.go new file mode 100644 index 0000000000..f610ad9187 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineage_update.go @@ -0,0 +1,360 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageUpdate is the builder for updating CreditRealizationLineage entities. +type CreditRealizationLineageUpdate struct { + config + hooks []Hook + mutation *CreditRealizationLineageMutation +} + +// Where appends a list predicates to the CreditRealizationLineageUpdate builder. +func (_u *CreditRealizationLineageUpdate) Where(ps ...predicate.CreditRealizationLineage) *CreditRealizationLineageUpdate { + _u.mutation.Where(ps...) + return _u +} + +// AddSegmentIDs adds the "segments" edge to the CreditRealizationLineageSegment entity by IDs. +func (_u *CreditRealizationLineageUpdate) AddSegmentIDs(ids ...string) *CreditRealizationLineageUpdate { + _u.mutation.AddSegmentIDs(ids...) + return _u +} + +// AddSegments adds the "segments" edges to the CreditRealizationLineageSegment entity. +func (_u *CreditRealizationLineageUpdate) AddSegments(v ...*CreditRealizationLineageSegment) *CreditRealizationLineageUpdate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddSegmentIDs(ids...) +} + +// Mutation returns the CreditRealizationLineageMutation object of the builder. +func (_u *CreditRealizationLineageUpdate) Mutation() *CreditRealizationLineageMutation { + return _u.mutation +} + +// ClearSegments clears all "segments" edges to the CreditRealizationLineageSegment entity. +func (_u *CreditRealizationLineageUpdate) ClearSegments() *CreditRealizationLineageUpdate { + _u.mutation.ClearSegments() + return _u +} + +// RemoveSegmentIDs removes the "segments" edge to CreditRealizationLineageSegment entities by IDs. +func (_u *CreditRealizationLineageUpdate) RemoveSegmentIDs(ids ...string) *CreditRealizationLineageUpdate { + _u.mutation.RemoveSegmentIDs(ids...) + return _u +} + +// RemoveSegments removes "segments" edges to CreditRealizationLineageSegment entities. +func (_u *CreditRealizationLineageUpdate) RemoveSegments(v ...*CreditRealizationLineageSegment) *CreditRealizationLineageUpdate { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveSegmentIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *CreditRealizationLineageUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *CreditRealizationLineageUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *CreditRealizationLineageUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *CreditRealizationLineageUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *CreditRealizationLineageUpdate) check() error { + if _u.mutation.ChargeCleared() && len(_u.mutation.ChargeIDs()) > 0 { + return errors.New(`db: clearing a required unique edge "CreditRealizationLineage.charge"`) + } + return nil +} + +func (_u *CreditRealizationLineageUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(creditrealizationlineage.Table, creditrealizationlineage.Columns, sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if _u.mutation.SegmentsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedSegmentsIDs(); len(nodes) > 0 && !_u.mutation.SegmentsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.SegmentsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{creditrealizationlineage.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// CreditRealizationLineageUpdateOne is the builder for updating a single CreditRealizationLineage entity. +type CreditRealizationLineageUpdateOne struct { + config + fields []string + hooks []Hook + mutation *CreditRealizationLineageMutation +} + +// AddSegmentIDs adds the "segments" edge to the CreditRealizationLineageSegment entity by IDs. +func (_u *CreditRealizationLineageUpdateOne) AddSegmentIDs(ids ...string) *CreditRealizationLineageUpdateOne { + _u.mutation.AddSegmentIDs(ids...) + return _u +} + +// AddSegments adds the "segments" edges to the CreditRealizationLineageSegment entity. +func (_u *CreditRealizationLineageUpdateOne) AddSegments(v ...*CreditRealizationLineageSegment) *CreditRealizationLineageUpdateOne { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddSegmentIDs(ids...) +} + +// Mutation returns the CreditRealizationLineageMutation object of the builder. +func (_u *CreditRealizationLineageUpdateOne) Mutation() *CreditRealizationLineageMutation { + return _u.mutation +} + +// ClearSegments clears all "segments" edges to the CreditRealizationLineageSegment entity. +func (_u *CreditRealizationLineageUpdateOne) ClearSegments() *CreditRealizationLineageUpdateOne { + _u.mutation.ClearSegments() + return _u +} + +// RemoveSegmentIDs removes the "segments" edge to CreditRealizationLineageSegment entities by IDs. +func (_u *CreditRealizationLineageUpdateOne) RemoveSegmentIDs(ids ...string) *CreditRealizationLineageUpdateOne { + _u.mutation.RemoveSegmentIDs(ids...) + return _u +} + +// RemoveSegments removes "segments" edges to CreditRealizationLineageSegment entities. +func (_u *CreditRealizationLineageUpdateOne) RemoveSegments(v ...*CreditRealizationLineageSegment) *CreditRealizationLineageUpdateOne { + ids := make([]string, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveSegmentIDs(ids...) +} + +// Where appends a list predicates to the CreditRealizationLineageUpdate builder. +func (_u *CreditRealizationLineageUpdateOne) Where(ps ...predicate.CreditRealizationLineage) *CreditRealizationLineageUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *CreditRealizationLineageUpdateOne) Select(field string, fields ...string) *CreditRealizationLineageUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated CreditRealizationLineage entity. +func (_u *CreditRealizationLineageUpdateOne) Save(ctx context.Context) (*CreditRealizationLineage, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *CreditRealizationLineageUpdateOne) SaveX(ctx context.Context) *CreditRealizationLineage { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *CreditRealizationLineageUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *CreditRealizationLineageUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *CreditRealizationLineageUpdateOne) check() error { + if _u.mutation.ChargeCleared() && len(_u.mutation.ChargeIDs()) > 0 { + return errors.New(`db: clearing a required unique edge "CreditRealizationLineage.charge"`) + } + return nil +} + +func (_u *CreditRealizationLineageUpdateOne) sqlSave(ctx context.Context) (_node *CreditRealizationLineage, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(creditrealizationlineage.Table, creditrealizationlineage.Columns, sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "CreditRealizationLineage.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, creditrealizationlineage.FieldID) + for _, f := range fields { + if !creditrealizationlineage.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != creditrealizationlineage.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if _u.mutation.SegmentsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedSegmentsIDs(); len(nodes) > 0 && !_u.mutation.SegmentsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.SegmentsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: creditrealizationlineage.SegmentsTable, + Columns: []string{creditrealizationlineage.SegmentsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &CreditRealizationLineage{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{creditrealizationlineage.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment.go b/openmeter/ent/db/creditrealizationlineagesegment.go new file mode 100644 index 0000000000..9089398f2a --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment.go @@ -0,0 +1,198 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" +) + +// CreditRealizationLineageSegment is the model entity for the CreditRealizationLineageSegment schema. +type CreditRealizationLineageSegment struct { + config `json:"-"` + // ID of the ent. + ID string `json:"id,omitempty"` + // LineageID holds the value of the "lineage_id" field. + LineageID string `json:"lineage_id,omitempty"` + // Amount holds the value of the "amount" field. + Amount alpacadecimal.Decimal `json:"amount,omitempty"` + // State holds the value of the "state" field. + State creditrealization.LineageSegmentState `json:"state,omitempty"` + // BackingTransactionGroupID holds the value of the "backing_transaction_group_id" field. + BackingTransactionGroupID *string `json:"backing_transaction_group_id,omitempty"` + // ClosedAt holds the value of the "closed_at" field. + ClosedAt *time.Time `json:"closed_at,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the CreditRealizationLineageSegmentQuery when eager-loading is set. + Edges CreditRealizationLineageSegmentEdges `json:"edges"` + selectValues sql.SelectValues +} + +// CreditRealizationLineageSegmentEdges holds the relations/edges for other nodes in the graph. +type CreditRealizationLineageSegmentEdges struct { + // Lineage holds the value of the lineage edge. + Lineage *CreditRealizationLineage `json:"lineage,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// LineageOrErr returns the Lineage value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e CreditRealizationLineageSegmentEdges) LineageOrErr() (*CreditRealizationLineage, error) { + if e.Lineage != nil { + return e.Lineage, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: creditrealizationlineage.Label} + } + return nil, &NotLoadedError{edge: "lineage"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*CreditRealizationLineageSegment) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case creditrealizationlineagesegment.FieldAmount: + values[i] = new(alpacadecimal.Decimal) + case creditrealizationlineagesegment.FieldID, creditrealizationlineagesegment.FieldLineageID, creditrealizationlineagesegment.FieldState, creditrealizationlineagesegment.FieldBackingTransactionGroupID: + values[i] = new(sql.NullString) + case creditrealizationlineagesegment.FieldClosedAt, creditrealizationlineagesegment.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the CreditRealizationLineageSegment fields. +func (_m *CreditRealizationLineageSegment) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case creditrealizationlineagesegment.FieldID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value.Valid { + _m.ID = value.String + } + case creditrealizationlineagesegment.FieldLineageID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field lineage_id", values[i]) + } else if value.Valid { + _m.LineageID = value.String + } + case creditrealizationlineagesegment.FieldAmount: + if value, ok := values[i].(*alpacadecimal.Decimal); !ok { + return fmt.Errorf("unexpected type %T for field amount", values[i]) + } else if value != nil { + _m.Amount = *value + } + case creditrealizationlineagesegment.FieldState: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field state", values[i]) + } else if value.Valid { + _m.State = creditrealization.LineageSegmentState(value.String) + } + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field backing_transaction_group_id", values[i]) + } else if value.Valid { + _m.BackingTransactionGroupID = new(string) + *_m.BackingTransactionGroupID = value.String + } + case creditrealizationlineagesegment.FieldClosedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field closed_at", values[i]) + } else if value.Valid { + _m.ClosedAt = new(time.Time) + *_m.ClosedAt = value.Time + } + case creditrealizationlineagesegment.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the CreditRealizationLineageSegment. +// This includes values selected through modifiers, order, etc. +func (_m *CreditRealizationLineageSegment) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryLineage queries the "lineage" edge of the CreditRealizationLineageSegment entity. +func (_m *CreditRealizationLineageSegment) QueryLineage() *CreditRealizationLineageQuery { + return NewCreditRealizationLineageSegmentClient(_m.config).QueryLineage(_m) +} + +// Update returns a builder for updating this CreditRealizationLineageSegment. +// Note that you need to call CreditRealizationLineageSegment.Unwrap() before calling this method if this CreditRealizationLineageSegment +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *CreditRealizationLineageSegment) Update() *CreditRealizationLineageSegmentUpdateOne { + return NewCreditRealizationLineageSegmentClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the CreditRealizationLineageSegment entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *CreditRealizationLineageSegment) Unwrap() *CreditRealizationLineageSegment { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("db: CreditRealizationLineageSegment is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *CreditRealizationLineageSegment) String() string { + var builder strings.Builder + builder.WriteString("CreditRealizationLineageSegment(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("lineage_id=") + builder.WriteString(_m.LineageID) + builder.WriteString(", ") + builder.WriteString("amount=") + builder.WriteString(fmt.Sprintf("%v", _m.Amount)) + builder.WriteString(", ") + builder.WriteString("state=") + builder.WriteString(fmt.Sprintf("%v", _m.State)) + builder.WriteString(", ") + if v := _m.BackingTransactionGroupID; v != nil { + builder.WriteString("backing_transaction_group_id=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ClosedAt; v != nil { + builder.WriteString("closed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// CreditRealizationLineageSegments is a parsable slice of CreditRealizationLineageSegment. +type CreditRealizationLineageSegments []*CreditRealizationLineageSegment diff --git a/openmeter/ent/db/creditrealizationlineagesegment/creditrealizationlineagesegment.go b/openmeter/ent/db/creditrealizationlineagesegment/creditrealizationlineagesegment.go new file mode 100644 index 0000000000..03c0bc2703 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment/creditrealizationlineagesegment.go @@ -0,0 +1,136 @@ +// Code generated by ent, DO NOT EDIT. + +package creditrealizationlineagesegment + +import ( + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" +) + +const ( + // Label holds the string label denoting the creditrealizationlineagesegment type in the database. + Label = "credit_realization_lineage_segment" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldLineageID holds the string denoting the lineage_id field in the database. + FieldLineageID = "lineage_id" + // FieldAmount holds the string denoting the amount field in the database. + FieldAmount = "amount" + // FieldState holds the string denoting the state field in the database. + FieldState = "state" + // FieldBackingTransactionGroupID holds the string denoting the backing_transaction_group_id field in the database. + FieldBackingTransactionGroupID = "backing_transaction_group_id" + // FieldClosedAt holds the string denoting the closed_at field in the database. + FieldClosedAt = "closed_at" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // EdgeLineage holds the string denoting the lineage edge name in mutations. + EdgeLineage = "lineage" + // Table holds the table name of the creditrealizationlineagesegment in the database. + Table = "credit_realization_lineage_segments" + // LineageTable is the table that holds the lineage relation/edge. + LineageTable = "credit_realization_lineage_segments" + // LineageInverseTable is the table name for the CreditRealizationLineage entity. + // It exists in this package in order to avoid circular dependency with the "creditrealizationlineage" package. + LineageInverseTable = "credit_realization_lineages" + // LineageColumn is the table column denoting the lineage relation/edge. + LineageColumn = "lineage_id" +) + +// Columns holds all SQL columns for creditrealizationlineagesegment fields. +var Columns = []string{ + FieldID, + FieldLineageID, + FieldAmount, + FieldState, + FieldBackingTransactionGroupID, + FieldClosedAt, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // LineageIDValidator is a validator for the "lineage_id" field. It is called by the builders before save. + LineageIDValidator func(string) error + // BackingTransactionGroupIDValidator is a validator for the "backing_transaction_group_id" field. It is called by the builders before save. + BackingTransactionGroupIDValidator func(string) error + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultID holds the default value on creation for the "id" field. + DefaultID func() string +) + +// StateValidator is a validator for the "state" field enum values. It is called by the builders before save. +func StateValidator(s creditrealization.LineageSegmentState) error { + switch s { + case "real_credit", "advance_uncovered", "advance_backfilled": + return nil + default: + return fmt.Errorf("creditrealizationlineagesegment: invalid enum value for state field: %q", s) + } +} + +// OrderOption defines the ordering options for the CreditRealizationLineageSegment queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByLineageID orders the results by the lineage_id field. +func ByLineageID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLineageID, opts...).ToFunc() +} + +// ByAmount orders the results by the amount field. +func ByAmount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAmount, opts...).ToFunc() +} + +// ByState orders the results by the state field. +func ByState(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldState, opts...).ToFunc() +} + +// ByBackingTransactionGroupID orders the results by the backing_transaction_group_id field. +func ByBackingTransactionGroupID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBackingTransactionGroupID, opts...).ToFunc() +} + +// ByClosedAt orders the results by the closed_at field. +func ByClosedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClosedAt, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByLineageField orders the results by lineage field. +func ByLineageField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newLineageStep(), sql.OrderByField(field, opts...)) + } +} +func newLineageStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(LineageInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, LineageTable, LineageColumn), + ) +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment/where.go b/openmeter/ent/db/creditrealizationlineagesegment/where.go new file mode 100644 index 0000000000..ad70e7c649 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment/where.go @@ -0,0 +1,431 @@ +// Code generated by ent, DO NOT EDIT. + +package creditrealizationlineagesegment + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldID, id)) +} + +// IDEqualFold applies the EqualFold predicate on the ID field. +func IDEqualFold(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEqualFold(FieldID, id)) +} + +// IDContainsFold applies the ContainsFold predicate on the ID field. +func IDContainsFold(id string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldContainsFold(FieldID, id)) +} + +// LineageID applies equality check predicate on the "lineage_id" field. It's identical to LineageIDEQ. +func LineageID(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldLineageID, v)) +} + +// Amount applies equality check predicate on the "amount" field. It's identical to AmountEQ. +func Amount(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldAmount, v)) +} + +// BackingTransactionGroupID applies equality check predicate on the "backing_transaction_group_id" field. It's identical to BackingTransactionGroupIDEQ. +func BackingTransactionGroupID(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldBackingTransactionGroupID, v)) +} + +// ClosedAt applies equality check predicate on the "closed_at" field. It's identical to ClosedAtEQ. +func ClosedAt(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldClosedAt, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldCreatedAt, v)) +} + +// LineageIDEQ applies the EQ predicate on the "lineage_id" field. +func LineageIDEQ(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldLineageID, v)) +} + +// LineageIDNEQ applies the NEQ predicate on the "lineage_id" field. +func LineageIDNEQ(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldLineageID, v)) +} + +// LineageIDIn applies the In predicate on the "lineage_id" field. +func LineageIDIn(vs ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldLineageID, vs...)) +} + +// LineageIDNotIn applies the NotIn predicate on the "lineage_id" field. +func LineageIDNotIn(vs ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldLineageID, vs...)) +} + +// LineageIDGT applies the GT predicate on the "lineage_id" field. +func LineageIDGT(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldLineageID, v)) +} + +// LineageIDGTE applies the GTE predicate on the "lineage_id" field. +func LineageIDGTE(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldLineageID, v)) +} + +// LineageIDLT applies the LT predicate on the "lineage_id" field. +func LineageIDLT(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldLineageID, v)) +} + +// LineageIDLTE applies the LTE predicate on the "lineage_id" field. +func LineageIDLTE(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldLineageID, v)) +} + +// LineageIDContains applies the Contains predicate on the "lineage_id" field. +func LineageIDContains(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldContains(FieldLineageID, v)) +} + +// LineageIDHasPrefix applies the HasPrefix predicate on the "lineage_id" field. +func LineageIDHasPrefix(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldHasPrefix(FieldLineageID, v)) +} + +// LineageIDHasSuffix applies the HasSuffix predicate on the "lineage_id" field. +func LineageIDHasSuffix(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldHasSuffix(FieldLineageID, v)) +} + +// LineageIDEqualFold applies the EqualFold predicate on the "lineage_id" field. +func LineageIDEqualFold(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEqualFold(FieldLineageID, v)) +} + +// LineageIDContainsFold applies the ContainsFold predicate on the "lineage_id" field. +func LineageIDContainsFold(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldContainsFold(FieldLineageID, v)) +} + +// AmountEQ applies the EQ predicate on the "amount" field. +func AmountEQ(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldAmount, v)) +} + +// AmountNEQ applies the NEQ predicate on the "amount" field. +func AmountNEQ(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldAmount, v)) +} + +// AmountIn applies the In predicate on the "amount" field. +func AmountIn(vs ...alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldAmount, vs...)) +} + +// AmountNotIn applies the NotIn predicate on the "amount" field. +func AmountNotIn(vs ...alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldAmount, vs...)) +} + +// AmountGT applies the GT predicate on the "amount" field. +func AmountGT(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldAmount, v)) +} + +// AmountGTE applies the GTE predicate on the "amount" field. +func AmountGTE(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldAmount, v)) +} + +// AmountLT applies the LT predicate on the "amount" field. +func AmountLT(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldAmount, v)) +} + +// AmountLTE applies the LTE predicate on the "amount" field. +func AmountLTE(v alpacadecimal.Decimal) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldAmount, v)) +} + +// StateEQ applies the EQ predicate on the "state" field. +func StateEQ(v creditrealization.LineageSegmentState) predicate.CreditRealizationLineageSegment { + vc := v + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldState, vc)) +} + +// StateNEQ applies the NEQ predicate on the "state" field. +func StateNEQ(v creditrealization.LineageSegmentState) predicate.CreditRealizationLineageSegment { + vc := v + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldState, vc)) +} + +// StateIn applies the In predicate on the "state" field. +func StateIn(vs ...creditrealization.LineageSegmentState) predicate.CreditRealizationLineageSegment { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldState, v...)) +} + +// StateNotIn applies the NotIn predicate on the "state" field. +func StateNotIn(vs ...creditrealization.LineageSegmentState) predicate.CreditRealizationLineageSegment { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldState, v...)) +} + +// BackingTransactionGroupIDEQ applies the EQ predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDEQ(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDNEQ applies the NEQ predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDNEQ(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDIn applies the In predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDIn(vs ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldBackingTransactionGroupID, vs...)) +} + +// BackingTransactionGroupIDNotIn applies the NotIn predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDNotIn(vs ...string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldBackingTransactionGroupID, vs...)) +} + +// BackingTransactionGroupIDGT applies the GT predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDGT(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDGTE applies the GTE predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDGTE(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDLT applies the LT predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDLT(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDLTE applies the LTE predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDLTE(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDContains applies the Contains predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDContains(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldContains(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDHasPrefix applies the HasPrefix predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDHasPrefix(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldHasPrefix(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDHasSuffix applies the HasSuffix predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDHasSuffix(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldHasSuffix(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDIsNil applies the IsNil predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDIsNil() predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIsNull(FieldBackingTransactionGroupID)) +} + +// BackingTransactionGroupIDNotNil applies the NotNil predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDNotNil() predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotNull(FieldBackingTransactionGroupID)) +} + +// BackingTransactionGroupIDEqualFold applies the EqualFold predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDEqualFold(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEqualFold(FieldBackingTransactionGroupID, v)) +} + +// BackingTransactionGroupIDContainsFold applies the ContainsFold predicate on the "backing_transaction_group_id" field. +func BackingTransactionGroupIDContainsFold(v string) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldContainsFold(FieldBackingTransactionGroupID, v)) +} + +// ClosedAtEQ applies the EQ predicate on the "closed_at" field. +func ClosedAtEQ(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldClosedAt, v)) +} + +// ClosedAtNEQ applies the NEQ predicate on the "closed_at" field. +func ClosedAtNEQ(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldClosedAt, v)) +} + +// ClosedAtIn applies the In predicate on the "closed_at" field. +func ClosedAtIn(vs ...time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldClosedAt, vs...)) +} + +// ClosedAtNotIn applies the NotIn predicate on the "closed_at" field. +func ClosedAtNotIn(vs ...time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldClosedAt, vs...)) +} + +// ClosedAtGT applies the GT predicate on the "closed_at" field. +func ClosedAtGT(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldClosedAt, v)) +} + +// ClosedAtGTE applies the GTE predicate on the "closed_at" field. +func ClosedAtGTE(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldClosedAt, v)) +} + +// ClosedAtLT applies the LT predicate on the "closed_at" field. +func ClosedAtLT(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldClosedAt, v)) +} + +// ClosedAtLTE applies the LTE predicate on the "closed_at" field. +func ClosedAtLTE(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldClosedAt, v)) +} + +// ClosedAtIsNil applies the IsNil predicate on the "closed_at" field. +func ClosedAtIsNil() predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIsNull(FieldClosedAt)) +} + +// ClosedAtNotNil applies the NotNil predicate on the "closed_at" field. +func ClosedAtNotNil() predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotNull(FieldClosedAt)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.FieldLTE(FieldCreatedAt, v)) +} + +// HasLineage applies the HasEdge predicate on the "lineage" edge. +func HasLineage() predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, LineageTable, LineageColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasLineageWith applies the HasEdge predicate on the "lineage" edge with a given conditions (other predicates). +func HasLineageWith(preds ...predicate.CreditRealizationLineage) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(func(s *sql.Selector) { + step := newLineageStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.CreditRealizationLineageSegment) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.CreditRealizationLineageSegment) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.CreditRealizationLineageSegment) predicate.CreditRealizationLineageSegment { + return predicate.CreditRealizationLineageSegment(sql.NotPredicates(p)) +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment_create.go b/openmeter/ent/db/creditrealizationlineagesegment_create.go new file mode 100644 index 0000000000..eb8486a722 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment_create.go @@ -0,0 +1,772 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" +) + +// CreditRealizationLineageSegmentCreate is the builder for creating a CreditRealizationLineageSegment entity. +type CreditRealizationLineageSegmentCreate struct { + config + mutation *CreditRealizationLineageSegmentMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetLineageID sets the "lineage_id" field. +func (_c *CreditRealizationLineageSegmentCreate) SetLineageID(v string) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetLineageID(v) + return _c +} + +// SetAmount sets the "amount" field. +func (_c *CreditRealizationLineageSegmentCreate) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetAmount(v) + return _c +} + +// SetState sets the "state" field. +func (_c *CreditRealizationLineageSegmentCreate) SetState(v creditrealization.LineageSegmentState) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetState(v) + return _c +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (_c *CreditRealizationLineageSegmentCreate) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetBackingTransactionGroupID(v) + return _c +} + +// SetNillableBackingTransactionGroupID sets the "backing_transaction_group_id" field if the given value is not nil. +func (_c *CreditRealizationLineageSegmentCreate) SetNillableBackingTransactionGroupID(v *string) *CreditRealizationLineageSegmentCreate { + if v != nil { + _c.SetBackingTransactionGroupID(*v) + } + return _c +} + +// SetClosedAt sets the "closed_at" field. +func (_c *CreditRealizationLineageSegmentCreate) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetClosedAt(v) + return _c +} + +// SetNillableClosedAt sets the "closed_at" field if the given value is not nil. +func (_c *CreditRealizationLineageSegmentCreate) SetNillableClosedAt(v *time.Time) *CreditRealizationLineageSegmentCreate { + if v != nil { + _c.SetClosedAt(*v) + } + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *CreditRealizationLineageSegmentCreate) SetCreatedAt(v time.Time) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *CreditRealizationLineageSegmentCreate) SetNillableCreatedAt(v *time.Time) *CreditRealizationLineageSegmentCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetID sets the "id" field. +func (_c *CreditRealizationLineageSegmentCreate) SetID(v string) *CreditRealizationLineageSegmentCreate { + _c.mutation.SetID(v) + return _c +} + +// SetNillableID sets the "id" field if the given value is not nil. +func (_c *CreditRealizationLineageSegmentCreate) SetNillableID(v *string) *CreditRealizationLineageSegmentCreate { + if v != nil { + _c.SetID(*v) + } + return _c +} + +// SetLineage sets the "lineage" edge to the CreditRealizationLineage entity. +func (_c *CreditRealizationLineageSegmentCreate) SetLineage(v *CreditRealizationLineage) *CreditRealizationLineageSegmentCreate { + return _c.SetLineageID(v.ID) +} + +// Mutation returns the CreditRealizationLineageSegmentMutation object of the builder. +func (_c *CreditRealizationLineageSegmentCreate) Mutation() *CreditRealizationLineageSegmentMutation { + return _c.mutation +} + +// Save creates the CreditRealizationLineageSegment in the database. +func (_c *CreditRealizationLineageSegmentCreate) Save(ctx context.Context) (*CreditRealizationLineageSegment, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *CreditRealizationLineageSegmentCreate) SaveX(ctx context.Context) *CreditRealizationLineageSegment { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *CreditRealizationLineageSegmentCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *CreditRealizationLineageSegmentCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *CreditRealizationLineageSegmentCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := creditrealizationlineagesegment.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.ID(); !ok { + v := creditrealizationlineagesegment.DefaultID() + _c.mutation.SetID(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *CreditRealizationLineageSegmentCreate) check() error { + if _, ok := _c.mutation.LineageID(); !ok { + return &ValidationError{Name: "lineage_id", err: errors.New(`db: missing required field "CreditRealizationLineageSegment.lineage_id"`)} + } + if v, ok := _c.mutation.LineageID(); ok { + if err := creditrealizationlineagesegment.LineageIDValidator(v); err != nil { + return &ValidationError{Name: "lineage_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineageSegment.lineage_id": %w`, err)} + } + } + if _, ok := _c.mutation.Amount(); !ok { + return &ValidationError{Name: "amount", err: errors.New(`db: missing required field "CreditRealizationLineageSegment.amount"`)} + } + if _, ok := _c.mutation.State(); !ok { + return &ValidationError{Name: "state", err: errors.New(`db: missing required field "CreditRealizationLineageSegment.state"`)} + } + if v, ok := _c.mutation.State(); ok { + if err := creditrealizationlineagesegment.StateValidator(v); err != nil { + return &ValidationError{Name: "state", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineageSegment.state": %w`, err)} + } + } + if v, ok := _c.mutation.BackingTransactionGroupID(); ok { + if err := creditrealizationlineagesegment.BackingTransactionGroupIDValidator(v); err != nil { + return &ValidationError{Name: "backing_transaction_group_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineageSegment.backing_transaction_group_id": %w`, err)} + } + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "CreditRealizationLineageSegment.created_at"`)} + } + if len(_c.mutation.LineageIDs()) == 0 { + return &ValidationError{Name: "lineage", err: errors.New(`db: missing required edge "CreditRealizationLineageSegment.lineage"`)} + } + return nil +} + +func (_c *CreditRealizationLineageSegmentCreate) sqlSave(ctx context.Context) (*CreditRealizationLineageSegment, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(string); ok { + _node.ID = id + } else { + return nil, fmt.Errorf("unexpected CreditRealizationLineageSegment.ID type: %T", _spec.ID.Value) + } + } + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *CreditRealizationLineageSegmentCreate) createSpec() (*CreditRealizationLineageSegment, *sqlgraph.CreateSpec) { + var ( + _node = &CreditRealizationLineageSegment{config: _c.config} + _spec = sqlgraph.NewCreateSpec(creditrealizationlineagesegment.Table, sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString)) + ) + _spec.OnConflict = _c.conflict + if id, ok := _c.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if value, ok := _c.mutation.Amount(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldAmount, field.TypeOther, value) + _node.Amount = value + } + if value, ok := _c.mutation.State(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldState, field.TypeEnum, value) + _node.State = value + } + if value, ok := _c.mutation.BackingTransactionGroupID(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldBackingTransactionGroupID, field.TypeString, value) + _node.BackingTransactionGroupID = &value + } + if value, ok := _c.mutation.ClosedAt(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldClosedAt, field.TypeTime, value) + _node.ClosedAt = &value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if nodes := _c.mutation.LineageIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: creditrealizationlineagesegment.LineageTable, + Columns: []string{creditrealizationlineagesegment.LineageColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(creditrealizationlineage.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.LineageID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.CreditRealizationLineageSegment.Create(). +// SetLineageID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.CreditRealizationLineageSegmentUpsert) { +// SetLineageID(v+v). +// }). +// Exec(ctx) +func (_c *CreditRealizationLineageSegmentCreate) OnConflict(opts ...sql.ConflictOption) *CreditRealizationLineageSegmentUpsertOne { + _c.conflict = opts + return &CreditRealizationLineageSegmentUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *CreditRealizationLineageSegmentCreate) OnConflictColumns(columns ...string) *CreditRealizationLineageSegmentUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &CreditRealizationLineageSegmentUpsertOne{ + create: _c, + } +} + +type ( + // CreditRealizationLineageSegmentUpsertOne is the builder for "upsert"-ing + // one CreditRealizationLineageSegment node. + CreditRealizationLineageSegmentUpsertOne struct { + create *CreditRealizationLineageSegmentCreate + } + + // CreditRealizationLineageSegmentUpsert is the "OnConflict" setter. + CreditRealizationLineageSegmentUpsert struct { + *sql.UpdateSet + } +) + +// SetAmount sets the "amount" field. +func (u *CreditRealizationLineageSegmentUpsert) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpsert { + u.Set(creditrealizationlineagesegment.FieldAmount, v) + return u +} + +// UpdateAmount sets the "amount" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsert) UpdateAmount() *CreditRealizationLineageSegmentUpsert { + u.SetExcluded(creditrealizationlineagesegment.FieldAmount) + return u +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsert) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentUpsert { + u.Set(creditrealizationlineagesegment.FieldBackingTransactionGroupID, v) + return u +} + +// UpdateBackingTransactionGroupID sets the "backing_transaction_group_id" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsert) UpdateBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsert { + u.SetExcluded(creditrealizationlineagesegment.FieldBackingTransactionGroupID) + return u +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsert) ClearBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsert { + u.SetNull(creditrealizationlineagesegment.FieldBackingTransactionGroupID) + return u +} + +// SetClosedAt sets the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsert) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentUpsert { + u.Set(creditrealizationlineagesegment.FieldClosedAt, v) + return u +} + +// UpdateClosedAt sets the "closed_at" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsert) UpdateClosedAt() *CreditRealizationLineageSegmentUpsert { + u.SetExcluded(creditrealizationlineagesegment.FieldClosedAt) + return u +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsert) ClearClosedAt() *CreditRealizationLineageSegmentUpsert { + u.SetNull(creditrealizationlineagesegment.FieldClosedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(creditrealizationlineagesegment.FieldID) +// }), +// ). +// Exec(ctx) +func (u *CreditRealizationLineageSegmentUpsertOne) UpdateNewValues() *CreditRealizationLineageSegmentUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.ID(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldID) + } + if _, exists := u.create.mutation.LineageID(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldLineageID) + } + if _, exists := u.create.mutation.State(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldState) + } + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *CreditRealizationLineageSegmentUpsertOne) Ignore() *CreditRealizationLineageSegmentUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *CreditRealizationLineageSegmentUpsertOne) DoNothing() *CreditRealizationLineageSegmentUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the CreditRealizationLineageSegmentCreate.OnConflict +// documentation for more info. +func (u *CreditRealizationLineageSegmentUpsertOne) Update(set func(*CreditRealizationLineageSegmentUpsert)) *CreditRealizationLineageSegmentUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&CreditRealizationLineageSegmentUpsert{UpdateSet: update}) + })) + return u +} + +// SetAmount sets the "amount" field. +func (u *CreditRealizationLineageSegmentUpsertOne) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetAmount(v) + }) +} + +// UpdateAmount sets the "amount" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertOne) UpdateAmount() *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateAmount() + }) +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsertOne) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetBackingTransactionGroupID(v) + }) +} + +// UpdateBackingTransactionGroupID sets the "backing_transaction_group_id" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertOne) UpdateBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateBackingTransactionGroupID() + }) +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsertOne) ClearBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.ClearBackingTransactionGroupID() + }) +} + +// SetClosedAt sets the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsertOne) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetClosedAt(v) + }) +} + +// UpdateClosedAt sets the "closed_at" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertOne) UpdateClosedAt() *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateClosedAt() + }) +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsertOne) ClearClosedAt() *CreditRealizationLineageSegmentUpsertOne { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.ClearClosedAt() + }) +} + +// Exec executes the query. +func (u *CreditRealizationLineageSegmentUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("db: missing options for CreditRealizationLineageSegmentCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *CreditRealizationLineageSegmentUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *CreditRealizationLineageSegmentUpsertOne) ID(ctx context.Context) (id string, err error) { + if u.create.driver.Dialect() == dialect.MySQL { + // In case of "ON CONFLICT", there is no way to get back non-numeric ID + // fields from the database since MySQL does not support the RETURNING clause. + return id, errors.New("db: CreditRealizationLineageSegmentUpsertOne.ID is not supported by MySQL driver. Use CreditRealizationLineageSegmentUpsertOne.Exec instead") + } + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *CreditRealizationLineageSegmentUpsertOne) IDX(ctx context.Context) string { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// CreditRealizationLineageSegmentCreateBulk is the builder for creating many CreditRealizationLineageSegment entities in bulk. +type CreditRealizationLineageSegmentCreateBulk struct { + config + err error + builders []*CreditRealizationLineageSegmentCreate + conflict []sql.ConflictOption +} + +// Save creates the CreditRealizationLineageSegment entities in the database. +func (_c *CreditRealizationLineageSegmentCreateBulk) Save(ctx context.Context) ([]*CreditRealizationLineageSegment, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*CreditRealizationLineageSegment, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*CreditRealizationLineageSegmentMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *CreditRealizationLineageSegmentCreateBulk) SaveX(ctx context.Context) []*CreditRealizationLineageSegment { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *CreditRealizationLineageSegmentCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *CreditRealizationLineageSegmentCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.CreditRealizationLineageSegment.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.CreditRealizationLineageSegmentUpsert) { +// SetLineageID(v+v). +// }). +// Exec(ctx) +func (_c *CreditRealizationLineageSegmentCreateBulk) OnConflict(opts ...sql.ConflictOption) *CreditRealizationLineageSegmentUpsertBulk { + _c.conflict = opts + return &CreditRealizationLineageSegmentUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *CreditRealizationLineageSegmentCreateBulk) OnConflictColumns(columns ...string) *CreditRealizationLineageSegmentUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &CreditRealizationLineageSegmentUpsertBulk{ + create: _c, + } +} + +// CreditRealizationLineageSegmentUpsertBulk is the builder for "upsert"-ing +// a bulk of CreditRealizationLineageSegment nodes. +type CreditRealizationLineageSegmentUpsertBulk struct { + create *CreditRealizationLineageSegmentCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(creditrealizationlineagesegment.FieldID) +// }), +// ). +// Exec(ctx) +func (u *CreditRealizationLineageSegmentUpsertBulk) UpdateNewValues() *CreditRealizationLineageSegmentUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.ID(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldID) + } + if _, exists := b.mutation.LineageID(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldLineageID) + } + if _, exists := b.mutation.State(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldState) + } + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(creditrealizationlineagesegment.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.CreditRealizationLineageSegment.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *CreditRealizationLineageSegmentUpsertBulk) Ignore() *CreditRealizationLineageSegmentUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *CreditRealizationLineageSegmentUpsertBulk) DoNothing() *CreditRealizationLineageSegmentUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the CreditRealizationLineageSegmentCreateBulk.OnConflict +// documentation for more info. +func (u *CreditRealizationLineageSegmentUpsertBulk) Update(set func(*CreditRealizationLineageSegmentUpsert)) *CreditRealizationLineageSegmentUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&CreditRealizationLineageSegmentUpsert{UpdateSet: update}) + })) + return u +} + +// SetAmount sets the "amount" field. +func (u *CreditRealizationLineageSegmentUpsertBulk) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetAmount(v) + }) +} + +// UpdateAmount sets the "amount" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertBulk) UpdateAmount() *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateAmount() + }) +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsertBulk) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetBackingTransactionGroupID(v) + }) +} + +// UpdateBackingTransactionGroupID sets the "backing_transaction_group_id" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertBulk) UpdateBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateBackingTransactionGroupID() + }) +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (u *CreditRealizationLineageSegmentUpsertBulk) ClearBackingTransactionGroupID() *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.ClearBackingTransactionGroupID() + }) +} + +// SetClosedAt sets the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsertBulk) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.SetClosedAt(v) + }) +} + +// UpdateClosedAt sets the "closed_at" field to the value that was provided on create. +func (u *CreditRealizationLineageSegmentUpsertBulk) UpdateClosedAt() *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.UpdateClosedAt() + }) +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (u *CreditRealizationLineageSegmentUpsertBulk) ClearClosedAt() *CreditRealizationLineageSegmentUpsertBulk { + return u.Update(func(s *CreditRealizationLineageSegmentUpsert) { + s.ClearClosedAt() + }) +} + +// Exec executes the query. +func (u *CreditRealizationLineageSegmentUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("db: OnConflict was set for builder %d. Set it on the CreditRealizationLineageSegmentCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("db: missing options for CreditRealizationLineageSegmentCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *CreditRealizationLineageSegmentUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment_delete.go b/openmeter/ent/db/creditrealizationlineagesegment_delete.go new file mode 100644 index 0000000000..426d55613e --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageSegmentDelete is the builder for deleting a CreditRealizationLineageSegment entity. +type CreditRealizationLineageSegmentDelete struct { + config + hooks []Hook + mutation *CreditRealizationLineageSegmentMutation +} + +// Where appends a list predicates to the CreditRealizationLineageSegmentDelete builder. +func (_d *CreditRealizationLineageSegmentDelete) Where(ps ...predicate.CreditRealizationLineageSegment) *CreditRealizationLineageSegmentDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *CreditRealizationLineageSegmentDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *CreditRealizationLineageSegmentDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *CreditRealizationLineageSegmentDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(creditrealizationlineagesegment.Table, sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// CreditRealizationLineageSegmentDeleteOne is the builder for deleting a single CreditRealizationLineageSegment entity. +type CreditRealizationLineageSegmentDeleteOne struct { + _d *CreditRealizationLineageSegmentDelete +} + +// Where appends a list predicates to the CreditRealizationLineageSegmentDelete builder. +func (_d *CreditRealizationLineageSegmentDeleteOne) Where(ps ...predicate.CreditRealizationLineageSegment) *CreditRealizationLineageSegmentDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *CreditRealizationLineageSegmentDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{creditrealizationlineagesegment.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *CreditRealizationLineageSegmentDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment_query.go b/openmeter/ent/db/creditrealizationlineagesegment_query.go new file mode 100644 index 0000000000..8917e072ba --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageSegmentQuery is the builder for querying CreditRealizationLineageSegment entities. +type CreditRealizationLineageSegmentQuery struct { + config + ctx *QueryContext + order []creditrealizationlineagesegment.OrderOption + inters []Interceptor + predicates []predicate.CreditRealizationLineageSegment + withLineage *CreditRealizationLineageQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the CreditRealizationLineageSegmentQuery builder. +func (_q *CreditRealizationLineageSegmentQuery) Where(ps ...predicate.CreditRealizationLineageSegment) *CreditRealizationLineageSegmentQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *CreditRealizationLineageSegmentQuery) Limit(limit int) *CreditRealizationLineageSegmentQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *CreditRealizationLineageSegmentQuery) Offset(offset int) *CreditRealizationLineageSegmentQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *CreditRealizationLineageSegmentQuery) Unique(unique bool) *CreditRealizationLineageSegmentQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *CreditRealizationLineageSegmentQuery) Order(o ...creditrealizationlineagesegment.OrderOption) *CreditRealizationLineageSegmentQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryLineage chains the current query on the "lineage" edge. +func (_q *CreditRealizationLineageSegmentQuery) QueryLineage() *CreditRealizationLineageQuery { + query := (&CreditRealizationLineageClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.FieldID, selector), + sqlgraph.To(creditrealizationlineage.Table, creditrealizationlineage.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, creditrealizationlineagesegment.LineageTable, creditrealizationlineagesegment.LineageColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first CreditRealizationLineageSegment entity from the query. +// Returns a *NotFoundError when no CreditRealizationLineageSegment was found. +func (_q *CreditRealizationLineageSegmentQuery) First(ctx context.Context) (*CreditRealizationLineageSegment, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{creditrealizationlineagesegment.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) FirstX(ctx context.Context) *CreditRealizationLineageSegment { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first CreditRealizationLineageSegment ID from the query. +// Returns a *NotFoundError when no CreditRealizationLineageSegment ID was found. +func (_q *CreditRealizationLineageSegmentQuery) FirstID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{creditrealizationlineagesegment.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) FirstIDX(ctx context.Context) string { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single CreditRealizationLineageSegment entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one CreditRealizationLineageSegment entity is found. +// Returns a *NotFoundError when no CreditRealizationLineageSegment entities are found. +func (_q *CreditRealizationLineageSegmentQuery) Only(ctx context.Context) (*CreditRealizationLineageSegment, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{creditrealizationlineagesegment.Label} + default: + return nil, &NotSingularError{creditrealizationlineagesegment.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) OnlyX(ctx context.Context) *CreditRealizationLineageSegment { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only CreditRealizationLineageSegment ID in the query. +// Returns a *NotSingularError when more than one CreditRealizationLineageSegment ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *CreditRealizationLineageSegmentQuery) OnlyID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{creditrealizationlineagesegment.Label} + default: + err = &NotSingularError{creditrealizationlineagesegment.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) OnlyIDX(ctx context.Context) string { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of CreditRealizationLineageSegments. +func (_q *CreditRealizationLineageSegmentQuery) All(ctx context.Context) ([]*CreditRealizationLineageSegment, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*CreditRealizationLineageSegment, *CreditRealizationLineageSegmentQuery]() + return withInterceptors[[]*CreditRealizationLineageSegment](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) AllX(ctx context.Context) []*CreditRealizationLineageSegment { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of CreditRealizationLineageSegment IDs. +func (_q *CreditRealizationLineageSegmentQuery) IDs(ctx context.Context) (ids []string, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(creditrealizationlineagesegment.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) IDsX(ctx context.Context) []string { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *CreditRealizationLineageSegmentQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*CreditRealizationLineageSegmentQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *CreditRealizationLineageSegmentQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("db: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *CreditRealizationLineageSegmentQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the CreditRealizationLineageSegmentQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *CreditRealizationLineageSegmentQuery) Clone() *CreditRealizationLineageSegmentQuery { + if _q == nil { + return nil + } + return &CreditRealizationLineageSegmentQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]creditrealizationlineagesegment.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.CreditRealizationLineageSegment{}, _q.predicates...), + withLineage: _q.withLineage.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithLineage tells the query-builder to eager-load the nodes that are connected to +// the "lineage" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *CreditRealizationLineageSegmentQuery) WithLineage(opts ...func(*CreditRealizationLineageQuery)) *CreditRealizationLineageSegmentQuery { + query := (&CreditRealizationLineageClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withLineage = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// LineageID string `json:"lineage_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.CreditRealizationLineageSegment.Query(). +// GroupBy(creditrealizationlineagesegment.FieldLineageID). +// Aggregate(db.Count()). +// Scan(ctx, &v) +func (_q *CreditRealizationLineageSegmentQuery) GroupBy(field string, fields ...string) *CreditRealizationLineageSegmentGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &CreditRealizationLineageSegmentGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = creditrealizationlineagesegment.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// LineageID string `json:"lineage_id,omitempty"` +// } +// +// client.CreditRealizationLineageSegment.Query(). +// Select(creditrealizationlineagesegment.FieldLineageID). +// Scan(ctx, &v) +func (_q *CreditRealizationLineageSegmentQuery) Select(fields ...string) *CreditRealizationLineageSegmentSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &CreditRealizationLineageSegmentSelect{CreditRealizationLineageSegmentQuery: _q} + sbuild.label = creditrealizationlineagesegment.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a CreditRealizationLineageSegmentSelect configured with the given aggregations. +func (_q *CreditRealizationLineageSegmentQuery) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageSegmentSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *CreditRealizationLineageSegmentQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !creditrealizationlineagesegment.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *CreditRealizationLineageSegmentQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*CreditRealizationLineageSegment, error) { + var ( + nodes = []*CreditRealizationLineageSegment{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withLineage != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*CreditRealizationLineageSegment).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &CreditRealizationLineageSegment{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withLineage; query != nil { + if err := _q.loadLineage(ctx, query, nodes, nil, + func(n *CreditRealizationLineageSegment, e *CreditRealizationLineage) { n.Edges.Lineage = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *CreditRealizationLineageSegmentQuery) loadLineage(ctx context.Context, query *CreditRealizationLineageQuery, nodes []*CreditRealizationLineageSegment, init func(*CreditRealizationLineageSegment), assign func(*CreditRealizationLineageSegment, *CreditRealizationLineage)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*CreditRealizationLineageSegment) + for i := range nodes { + fk := nodes[i].LineageID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(creditrealizationlineage.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "lineage_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *CreditRealizationLineageSegmentQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *CreditRealizationLineageSegmentQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.Columns, sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, creditrealizationlineagesegment.FieldID) + for i := range fields { + if fields[i] != creditrealizationlineagesegment.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withLineage != nil { + _spec.Node.AddColumnOnce(creditrealizationlineagesegment.FieldLineageID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *CreditRealizationLineageSegmentQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(creditrealizationlineagesegment.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = creditrealizationlineagesegment.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *CreditRealizationLineageSegmentQuery) ForUpdate(opts ...sql.LockOption) *CreditRealizationLineageSegmentQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *CreditRealizationLineageSegmentQuery) ForShare(opts ...sql.LockOption) *CreditRealizationLineageSegmentQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// CreditRealizationLineageSegmentGroupBy is the group-by builder for CreditRealizationLineageSegment entities. +type CreditRealizationLineageSegmentGroupBy struct { + selector + build *CreditRealizationLineageSegmentQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *CreditRealizationLineageSegmentGroupBy) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageSegmentGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *CreditRealizationLineageSegmentGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CreditRealizationLineageSegmentQuery, *CreditRealizationLineageSegmentGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *CreditRealizationLineageSegmentGroupBy) sqlScan(ctx context.Context, root *CreditRealizationLineageSegmentQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// CreditRealizationLineageSegmentSelect is the builder for selecting fields of CreditRealizationLineageSegment entities. +type CreditRealizationLineageSegmentSelect struct { + *CreditRealizationLineageSegmentQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *CreditRealizationLineageSegmentSelect) Aggregate(fns ...AggregateFunc) *CreditRealizationLineageSegmentSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *CreditRealizationLineageSegmentSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*CreditRealizationLineageSegmentQuery, *CreditRealizationLineageSegmentSelect](ctx, _s.CreditRealizationLineageSegmentQuery, _s, _s.inters, v) +} + +func (_s *CreditRealizationLineageSegmentSelect) sqlScan(ctx context.Context, root *CreditRealizationLineageSegmentQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/openmeter/ent/db/creditrealizationlineagesegment_update.go b/openmeter/ent/db/creditrealizationlineagesegment_update.go new file mode 100644 index 0000000000..8b70ee9746 --- /dev/null +++ b/openmeter/ent/db/creditrealizationlineagesegment_update.go @@ -0,0 +1,347 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" +) + +// CreditRealizationLineageSegmentUpdate is the builder for updating CreditRealizationLineageSegment entities. +type CreditRealizationLineageSegmentUpdate struct { + config + hooks []Hook + mutation *CreditRealizationLineageSegmentMutation +} + +// Where appends a list predicates to the CreditRealizationLineageSegmentUpdate builder. +func (_u *CreditRealizationLineageSegmentUpdate) Where(ps ...predicate.CreditRealizationLineageSegment) *CreditRealizationLineageSegmentUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetAmount sets the "amount" field. +func (_u *CreditRealizationLineageSegmentUpdate) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpdate { + _u.mutation.SetAmount(v) + return _u +} + +// SetNillableAmount sets the "amount" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdate) SetNillableAmount(v *alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpdate { + if v != nil { + _u.SetAmount(*v) + } + return _u +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (_u *CreditRealizationLineageSegmentUpdate) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentUpdate { + _u.mutation.SetBackingTransactionGroupID(v) + return _u +} + +// SetNillableBackingTransactionGroupID sets the "backing_transaction_group_id" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdate) SetNillableBackingTransactionGroupID(v *string) *CreditRealizationLineageSegmentUpdate { + if v != nil { + _u.SetBackingTransactionGroupID(*v) + } + return _u +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (_u *CreditRealizationLineageSegmentUpdate) ClearBackingTransactionGroupID() *CreditRealizationLineageSegmentUpdate { + _u.mutation.ClearBackingTransactionGroupID() + return _u +} + +// SetClosedAt sets the "closed_at" field. +func (_u *CreditRealizationLineageSegmentUpdate) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentUpdate { + _u.mutation.SetClosedAt(v) + return _u +} + +// SetNillableClosedAt sets the "closed_at" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdate) SetNillableClosedAt(v *time.Time) *CreditRealizationLineageSegmentUpdate { + if v != nil { + _u.SetClosedAt(*v) + } + return _u +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (_u *CreditRealizationLineageSegmentUpdate) ClearClosedAt() *CreditRealizationLineageSegmentUpdate { + _u.mutation.ClearClosedAt() + return _u +} + +// Mutation returns the CreditRealizationLineageSegmentMutation object of the builder. +func (_u *CreditRealizationLineageSegmentUpdate) Mutation() *CreditRealizationLineageSegmentMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *CreditRealizationLineageSegmentUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *CreditRealizationLineageSegmentUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *CreditRealizationLineageSegmentUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *CreditRealizationLineageSegmentUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *CreditRealizationLineageSegmentUpdate) check() error { + if v, ok := _u.mutation.BackingTransactionGroupID(); ok { + if err := creditrealizationlineagesegment.BackingTransactionGroupIDValidator(v); err != nil { + return &ValidationError{Name: "backing_transaction_group_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineageSegment.backing_transaction_group_id": %w`, err)} + } + } + if _u.mutation.LineageCleared() && len(_u.mutation.LineageIDs()) > 0 { + return errors.New(`db: clearing a required unique edge "CreditRealizationLineageSegment.lineage"`) + } + return nil +} + +func (_u *CreditRealizationLineageSegmentUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.Columns, sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Amount(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldAmount, field.TypeOther, value) + } + if value, ok := _u.mutation.BackingTransactionGroupID(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldBackingTransactionGroupID, field.TypeString, value) + } + if _u.mutation.BackingTransactionGroupIDCleared() { + _spec.ClearField(creditrealizationlineagesegment.FieldBackingTransactionGroupID, field.TypeString) + } + if value, ok := _u.mutation.ClosedAt(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldClosedAt, field.TypeTime, value) + } + if _u.mutation.ClosedAtCleared() { + _spec.ClearField(creditrealizationlineagesegment.FieldClosedAt, field.TypeTime) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{creditrealizationlineagesegment.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// CreditRealizationLineageSegmentUpdateOne is the builder for updating a single CreditRealizationLineageSegment entity. +type CreditRealizationLineageSegmentUpdateOne struct { + config + fields []string + hooks []Hook + mutation *CreditRealizationLineageSegmentMutation +} + +// SetAmount sets the "amount" field. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetAmount(v alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.SetAmount(v) + return _u +} + +// SetNillableAmount sets the "amount" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetNillableAmount(v *alpacadecimal.Decimal) *CreditRealizationLineageSegmentUpdateOne { + if v != nil { + _u.SetAmount(*v) + } + return _u +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetBackingTransactionGroupID(v string) *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.SetBackingTransactionGroupID(v) + return _u +} + +// SetNillableBackingTransactionGroupID sets the "backing_transaction_group_id" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetNillableBackingTransactionGroupID(v *string) *CreditRealizationLineageSegmentUpdateOne { + if v != nil { + _u.SetBackingTransactionGroupID(*v) + } + return _u +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (_u *CreditRealizationLineageSegmentUpdateOne) ClearBackingTransactionGroupID() *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.ClearBackingTransactionGroupID() + return _u +} + +// SetClosedAt sets the "closed_at" field. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetClosedAt(v time.Time) *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.SetClosedAt(v) + return _u +} + +// SetNillableClosedAt sets the "closed_at" field if the given value is not nil. +func (_u *CreditRealizationLineageSegmentUpdateOne) SetNillableClosedAt(v *time.Time) *CreditRealizationLineageSegmentUpdateOne { + if v != nil { + _u.SetClosedAt(*v) + } + return _u +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (_u *CreditRealizationLineageSegmentUpdateOne) ClearClosedAt() *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.ClearClosedAt() + return _u +} + +// Mutation returns the CreditRealizationLineageSegmentMutation object of the builder. +func (_u *CreditRealizationLineageSegmentUpdateOne) Mutation() *CreditRealizationLineageSegmentMutation { + return _u.mutation +} + +// Where appends a list predicates to the CreditRealizationLineageSegmentUpdate builder. +func (_u *CreditRealizationLineageSegmentUpdateOne) Where(ps ...predicate.CreditRealizationLineageSegment) *CreditRealizationLineageSegmentUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *CreditRealizationLineageSegmentUpdateOne) Select(field string, fields ...string) *CreditRealizationLineageSegmentUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated CreditRealizationLineageSegment entity. +func (_u *CreditRealizationLineageSegmentUpdateOne) Save(ctx context.Context) (*CreditRealizationLineageSegment, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *CreditRealizationLineageSegmentUpdateOne) SaveX(ctx context.Context) *CreditRealizationLineageSegment { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *CreditRealizationLineageSegmentUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *CreditRealizationLineageSegmentUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *CreditRealizationLineageSegmentUpdateOne) check() error { + if v, ok := _u.mutation.BackingTransactionGroupID(); ok { + if err := creditrealizationlineagesegment.BackingTransactionGroupIDValidator(v); err != nil { + return &ValidationError{Name: "backing_transaction_group_id", err: fmt.Errorf(`db: validator failed for field "CreditRealizationLineageSegment.backing_transaction_group_id": %w`, err)} + } + } + if _u.mutation.LineageCleared() && len(_u.mutation.LineageIDs()) > 0 { + return errors.New(`db: clearing a required unique edge "CreditRealizationLineageSegment.lineage"`) + } + return nil +} + +func (_u *CreditRealizationLineageSegmentUpdateOne) sqlSave(ctx context.Context) (_node *CreditRealizationLineageSegment, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(creditrealizationlineagesegment.Table, creditrealizationlineagesegment.Columns, sqlgraph.NewFieldSpec(creditrealizationlineagesegment.FieldID, field.TypeString)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "CreditRealizationLineageSegment.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, creditrealizationlineagesegment.FieldID) + for _, f := range fields { + if !creditrealizationlineagesegment.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != creditrealizationlineagesegment.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Amount(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldAmount, field.TypeOther, value) + } + if value, ok := _u.mutation.BackingTransactionGroupID(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldBackingTransactionGroupID, field.TypeString, value) + } + if _u.mutation.BackingTransactionGroupIDCleared() { + _spec.ClearField(creditrealizationlineagesegment.FieldBackingTransactionGroupID, field.TypeString) + } + if value, ok := _u.mutation.ClosedAt(); ok { + _spec.SetField(creditrealizationlineagesegment.FieldClosedAt, field.TypeTime, value) + } + if _u.mutation.ClosedAtCleared() { + _spec.ClearField(creditrealizationlineagesegment.FieldClosedAt, field.TypeTime) + } + _node = &CreditRealizationLineageSegment{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{creditrealizationlineagesegment.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/openmeter/ent/db/cursor.go b/openmeter/ent/db/cursor.go index 01ed99c5d3..2572f26f7e 100644 --- a/openmeter/ent/db/cursor.go +++ b/openmeter/ent/db/cursor.go @@ -1745,6 +1745,108 @@ func (_m *ChargesSearchV1Query) Cursor(ctx context.Context, cursor *pagination.C return result, nil } +// Cursor runs the query and returns a cursor-paginated response. +// Ordering is always by created_at asc, id asc. +func (_m *CreditRealizationLineageQuery) Cursor(ctx context.Context, cursor *pagination.Cursor) (pagination.Result[*CreditRealizationLineage], error) { + if cursor != nil { + if err := cursor.Validate(); err != nil { + return pagination.Result[*CreditRealizationLineage]{}, fmt.Errorf("invalid cursor: %w", err) + } + + _m.Where(func(s *sql.Selector) { + s.Where( + sql.Or( + sql.GT(s.C("created_at"), cursor.Time), + sql.And( + sql.EQ(s.C("created_at"), cursor.Time), + sql.P(func(b *sql.Builder) { + b.WriteString("CAST(") + b.WriteString(s.C("id")) + b.WriteString(" AS TEXT) > ") + b.Args(cursor.ID) + }), + ), + ), + ) + }) + } + + _m.Order(func(s *sql.Selector) { + s.OrderBy(sql.Asc(s.C("created_at")), sql.Asc(s.C("id"))) + }) + + items, err := _m.All(ctx) + if err != nil { + return pagination.Result[*CreditRealizationLineage]{}, err + } + + if items == nil { + items = make([]*CreditRealizationLineage, 0) + } + + result := pagination.Result[*CreditRealizationLineage]{ + Items: items, + } + + if len(items) > 0 { + last := items[len(items)-1] + result.NextCursor = lo.ToPtr(pagination.NewCursor(last.CreatedAt, fmt.Sprint(last.ID))) + } + + return result, nil +} + +// Cursor runs the query and returns a cursor-paginated response. +// Ordering is always by created_at asc, id asc. +func (_m *CreditRealizationLineageSegmentQuery) Cursor(ctx context.Context, cursor *pagination.Cursor) (pagination.Result[*CreditRealizationLineageSegment], error) { + if cursor != nil { + if err := cursor.Validate(); err != nil { + return pagination.Result[*CreditRealizationLineageSegment]{}, fmt.Errorf("invalid cursor: %w", err) + } + + _m.Where(func(s *sql.Selector) { + s.Where( + sql.Or( + sql.GT(s.C("created_at"), cursor.Time), + sql.And( + sql.EQ(s.C("created_at"), cursor.Time), + sql.P(func(b *sql.Builder) { + b.WriteString("CAST(") + b.WriteString(s.C("id")) + b.WriteString(" AS TEXT) > ") + b.Args(cursor.ID) + }), + ), + ), + ) + }) + } + + _m.Order(func(s *sql.Selector) { + s.OrderBy(sql.Asc(s.C("created_at")), sql.Asc(s.C("id"))) + }) + + items, err := _m.All(ctx) + if err != nil { + return pagination.Result[*CreditRealizationLineageSegment]{}, err + } + + if items == nil { + items = make([]*CreditRealizationLineageSegment, 0) + } + + result := pagination.Result[*CreditRealizationLineageSegment]{ + Items: items, + } + + if len(items) > 0 { + last := items[len(items)-1] + result.NextCursor = lo.ToPtr(pagination.NewCursor(last.CreatedAt, fmt.Sprint(last.ID))) + } + + return result, nil +} + // Cursor runs the query and returns a cursor-paginated response. // Ordering is always by created_at asc, id asc. func (_m *CurrencyCostBasisQuery) Cursor(ctx context.Context, cursor *pagination.Cursor) (pagination.Result[*CurrencyCostBasis], error) { diff --git a/openmeter/ent/db/ent.go b/openmeter/ent/db/ent.go index eac6c20ef2..4e6a463f0d 100644 --- a/openmeter/ent/db/ent.go +++ b/openmeter/ent/db/ent.go @@ -52,6 +52,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruninvoicedusage" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedrunpayment" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruns" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" "github.com/openmeterio/openmeter/openmeter/ent/db/currencycostbasis" "github.com/openmeterio/openmeter/openmeter/ent/db/customcurrency" "github.com/openmeterio/openmeter/openmeter/ent/db/customer" @@ -188,6 +190,8 @@ func checkColumn(t, c string) error { chargeusagebasedrunpayment.Table: chargeusagebasedrunpayment.ValidColumn, chargeusagebasedruns.Table: chargeusagebasedruns.ValidColumn, chargessearchv1.Table: chargessearchv1.ValidColumn, + creditrealizationlineage.Table: creditrealizationlineage.ValidColumn, + creditrealizationlineagesegment.Table: creditrealizationlineagesegment.ValidColumn, currencycostbasis.Table: currencycostbasis.ValidColumn, customcurrency.Table: customcurrency.ValidColumn, customer.Table: customer.ValidColumn, diff --git a/openmeter/ent/db/entmixinaccessor.go b/openmeter/ent/db/entmixinaccessor.go index fb04a47fb4..94cf5a59bd 100644 --- a/openmeter/ent/db/entmixinaccessor.go +++ b/openmeter/ent/db/entmixinaccessor.go @@ -1707,6 +1707,18 @@ func (e *ChargeUsageBasedRuns) GetTotal() alpacadecimal.Decimal { return e.Total } +func (e *CreditRealizationLineage) GetID() string { + return e.ID +} + +func (e *CreditRealizationLineage) GetNamespace() string { + return e.Namespace +} + +func (e *CreditRealizationLineageSegment) GetID() string { + return e.ID +} + func (e *CurrencyCostBasis) GetID() string { return e.ID } diff --git a/openmeter/ent/db/expose.go b/openmeter/ent/db/expose.go index a6bcfd918e..d926a4d542 100644 --- a/openmeter/ent/db/expose.go +++ b/openmeter/ent/db/expose.go @@ -167,6 +167,10 @@ func NewTxClientFromRawConfig(ctx context.Context, cfg entutils.RawEntConfig) *T ChargesSearchV1: NewChargesSearchV1Client(config), + CreditRealizationLineage: NewCreditRealizationLineageClient(config), + + CreditRealizationLineageSegment: NewCreditRealizationLineageSegmentClient(config), + CurrencyCostBasis: NewCurrencyCostBasisClient(config), CustomCurrency: NewCustomCurrencyClient(config), diff --git a/openmeter/ent/db/hook/hook.go b/openmeter/ent/db/hook/hook.go index 0ac095fbf8..ab10d8e7fb 100644 --- a/openmeter/ent/db/hook/hook.go +++ b/openmeter/ent/db/hook/hook.go @@ -465,6 +465,30 @@ func (f ChargeUsageBasedRunsFunc) Mutate(ctx context.Context, m db.Mutation) (db return nil, fmt.Errorf("unexpected mutation type %T. expect *db.ChargeUsageBasedRunsMutation", m) } +// The CreditRealizationLineageFunc type is an adapter to allow the use of ordinary +// function as CreditRealizationLineage mutator. +type CreditRealizationLineageFunc func(context.Context, *db.CreditRealizationLineageMutation) (db.Value, error) + +// Mutate calls f(ctx, m). +func (f CreditRealizationLineageFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, error) { + if mv, ok := m.(*db.CreditRealizationLineageMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *db.CreditRealizationLineageMutation", m) +} + +// The CreditRealizationLineageSegmentFunc type is an adapter to allow the use of ordinary +// function as CreditRealizationLineageSegment mutator. +type CreditRealizationLineageSegmentFunc func(context.Context, *db.CreditRealizationLineageSegmentMutation) (db.Value, error) + +// Mutate calls f(ctx, m). +func (f CreditRealizationLineageSegmentFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, error) { + if mv, ok := m.(*db.CreditRealizationLineageSegmentMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *db.CreditRealizationLineageSegmentMutation", m) +} + // The CurrencyCostBasisFunc type is an adapter to allow the use of ordinary // function as CurrencyCostBasis mutator. type CurrencyCostBasisFunc func(context.Context, *db.CurrencyCostBasisMutation) (db.Value, error) diff --git a/openmeter/ent/db/migrate/schema.go b/openmeter/ent/db/migrate/schema.go index efba024019..460e3ca842 100644 --- a/openmeter/ent/db/migrate/schema.go +++ b/openmeter/ent/db/migrate/schema.go @@ -2510,6 +2510,99 @@ var ( }, }, } + // CreditRealizationLineagesColumns holds the columns for the "credit_realization_lineages" table. + CreditRealizationLineagesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeString, Unique: true, SchemaType: map[string]string{"postgres": "char(26)"}}, + {Name: "namespace", Type: field.TypeString}, + {Name: "root_realization_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, + {Name: "customer_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "origin_kind", Type: field.TypeEnum, Enums: []string{"real_credit", "advance"}}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "charge_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, + } + // CreditRealizationLineagesTable holds the schema information for the "credit_realization_lineages" table. + CreditRealizationLineagesTable = &schema.Table{ + Name: "credit_realization_lineages", + Columns: CreditRealizationLineagesColumns, + PrimaryKey: []*schema.Column{CreditRealizationLineagesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "credit_realization_lineages_charges_credit_realization_lineages", + Columns: []*schema.Column{CreditRealizationLineagesColumns[7]}, + RefColumns: []*schema.Column{ChargesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "creditrealizationlineage_id", + Unique: true, + Columns: []*schema.Column{CreditRealizationLineagesColumns[0]}, + }, + { + Name: "creditrealizationlineage_namespace", + Unique: false, + Columns: []*schema.Column{CreditRealizationLineagesColumns[1]}, + }, + { + Name: "creditrealizationlineage_namespace_root_realization_id", + Unique: true, + Columns: []*schema.Column{CreditRealizationLineagesColumns[1], CreditRealizationLineagesColumns[2]}, + }, + { + Name: "creditrealizationlineage_namespace_charge_id", + Unique: false, + Columns: []*schema.Column{CreditRealizationLineagesColumns[1], CreditRealizationLineagesColumns[7]}, + }, + { + Name: "creditrealizationlineage_namespace_customer_id", + Unique: false, + Columns: []*schema.Column{CreditRealizationLineagesColumns[1], CreditRealizationLineagesColumns[3]}, + }, + }, + } + // CreditRealizationLineageSegmentsColumns holds the columns for the "credit_realization_lineage_segments" table. + CreditRealizationLineageSegmentsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeString, Unique: true, SchemaType: map[string]string{"postgres": "char(26)"}}, + {Name: "amount", Type: field.TypeOther, SchemaType: map[string]string{"postgres": "numeric"}}, + {Name: "state", Type: field.TypeEnum, Enums: []string{"real_credit", "advance_uncovered", "advance_backfilled"}}, + {Name: "backing_transaction_group_id", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "char(26)"}}, + {Name: "closed_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "lineage_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, + } + // CreditRealizationLineageSegmentsTable holds the schema information for the "credit_realization_lineage_segments" table. + CreditRealizationLineageSegmentsTable = &schema.Table{ + Name: "credit_realization_lineage_segments", + Columns: CreditRealizationLineageSegmentsColumns, + PrimaryKey: []*schema.Column{CreditRealizationLineageSegmentsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "credit_realization_lineage_segments_credit_realization_lineages_segments", + Columns: []*schema.Column{CreditRealizationLineageSegmentsColumns[6]}, + RefColumns: []*schema.Column{CreditRealizationLineagesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "creditrealizationlineagesegment_id", + Unique: true, + Columns: []*schema.Column{CreditRealizationLineageSegmentsColumns[0]}, + }, + { + Name: "creditrealizationlineagesegment_lineage_id", + Unique: false, + Columns: []*schema.Column{CreditRealizationLineageSegmentsColumns[6]}, + }, + { + Name: "creditrealizationlineagesegment_lineage_id_closed_at", + Unique: false, + Columns: []*schema.Column{CreditRealizationLineageSegmentsColumns[6], CreditRealizationLineageSegmentsColumns[4]}, + }, + }, + } // CurrencyCostBasesColumns holds the columns for the "currency_cost_bases" table. CurrencyCostBasesColumns = []*schema.Column{ {Name: "id", Type: field.TypeString, Unique: true, SchemaType: map[string]string{"postgres": "char(26)"}}, @@ -4591,6 +4684,8 @@ var ( ChargeUsageBasedRunInvoicedUsagesTable, ChargeUsageBasedRunPaymentsTable, ChargeUsageBasedRunsTable, + CreditRealizationLineagesTable, + CreditRealizationLineageSegmentsTable, CurrencyCostBasesTable, CustomCurrenciesTable, CustomersTable, @@ -4715,6 +4810,8 @@ func init() { ChargeUsageBasedRunPaymentsTable.ForeignKeys[0].RefTable = ChargeUsageBasedRunsTable ChargeUsageBasedRunsTable.ForeignKeys[0].RefTable = ChargeUsageBasedTable ChargeUsageBasedRunsTable.ForeignKeys[1].RefTable = FeaturesTable + CreditRealizationLineagesTable.ForeignKeys[0].RefTable = ChargesTable + CreditRealizationLineageSegmentsTable.ForeignKeys[0].RefTable = CreditRealizationLineagesTable CurrencyCostBasesTable.ForeignKeys[0].RefTable = CustomCurrenciesTable CustomerSubjectsTable.ForeignKeys[0].RefTable = CustomersTable EntitlementsTable.ForeignKeys[0].RefTable = CustomersTable diff --git a/openmeter/ent/db/mutation.go b/openmeter/ent/db/mutation.go index 70faf96081..4c4f9a12ab 100644 --- a/openmeter/ent/db/mutation.go +++ b/openmeter/ent/db/mutation.go @@ -60,6 +60,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruninvoicedusage" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedrunpayment" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruns" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" "github.com/openmeterio/openmeter/openmeter/ent/db/currencycostbasis" "github.com/openmeterio/openmeter/openmeter/ent/db/customcurrency" "github.com/openmeterio/openmeter/openmeter/ent/db/customer" @@ -154,6 +156,8 @@ const ( TypeChargeUsageBasedRunPayment = "ChargeUsageBasedRunPayment" TypeChargeUsageBasedRuns = "ChargeUsageBasedRuns" TypeChargesSearchV1 = "ChargesSearchV1" + TypeCreditRealizationLineage = "CreditRealizationLineage" + TypeCreditRealizationLineageSegment = "CreditRealizationLineageSegment" TypeCurrencyCostBasis = "CurrencyCostBasis" TypeCustomCurrency = "CustomCurrency" TypeCustomer = "Customer" @@ -35376,30 +35380,33 @@ func (m *BillingWorkflowConfigMutation) ResetEdge(name string) error { // ChargeMutation represents an operation that mutates the Charge nodes in the graph. type ChargeMutation struct { config - op Op - typ string - id *string - namespace *string - created_at *time.Time - deleted_at *time.Time - unique_reference_id *string - _type *meta.ChargeType - clearedFields map[string]struct{} - flat_fee *string - clearedflat_fee bool - credit_purchase *string - clearedcredit_purchase bool - usage_based *string - clearedusage_based bool - billing_invoice_lines map[string]struct{} - removedbilling_invoice_lines map[string]struct{} - clearedbilling_invoice_lines bool - billing_split_line_groups map[string]struct{} - removedbilling_split_line_groups map[string]struct{} - clearedbilling_split_line_groups bool - done bool - oldValue func(context.Context) (*Charge, error) - predicates []predicate.Charge + op Op + typ string + id *string + namespace *string + created_at *time.Time + deleted_at *time.Time + unique_reference_id *string + _type *meta.ChargeType + clearedFields map[string]struct{} + flat_fee *string + clearedflat_fee bool + credit_purchase *string + clearedcredit_purchase bool + usage_based *string + clearedusage_based bool + billing_invoice_lines map[string]struct{} + removedbilling_invoice_lines map[string]struct{} + clearedbilling_invoice_lines bool + billing_split_line_groups map[string]struct{} + removedbilling_split_line_groups map[string]struct{} + clearedbilling_split_line_groups bool + credit_realization_lineages map[string]struct{} + removedcredit_realization_lineages map[string]struct{} + clearedcredit_realization_lineages bool + done bool + oldValue func(context.Context) (*Charge, error) + predicates []predicate.Charge } var _ ent.Mutation = (*ChargeMutation)(nil) @@ -36087,6 +36094,60 @@ func (m *ChargeMutation) ResetBillingSplitLineGroups() { m.removedbilling_split_line_groups = nil } +// AddCreditRealizationLineageIDs adds the "credit_realization_lineages" edge to the CreditRealizationLineage entity by ids. +func (m *ChargeMutation) AddCreditRealizationLineageIDs(ids ...string) { + if m.credit_realization_lineages == nil { + m.credit_realization_lineages = make(map[string]struct{}) + } + for i := range ids { + m.credit_realization_lineages[ids[i]] = struct{}{} + } +} + +// ClearCreditRealizationLineages clears the "credit_realization_lineages" edge to the CreditRealizationLineage entity. +func (m *ChargeMutation) ClearCreditRealizationLineages() { + m.clearedcredit_realization_lineages = true +} + +// CreditRealizationLineagesCleared reports if the "credit_realization_lineages" edge to the CreditRealizationLineage entity was cleared. +func (m *ChargeMutation) CreditRealizationLineagesCleared() bool { + return m.clearedcredit_realization_lineages +} + +// RemoveCreditRealizationLineageIDs removes the "credit_realization_lineages" edge to the CreditRealizationLineage entity by IDs. +func (m *ChargeMutation) RemoveCreditRealizationLineageIDs(ids ...string) { + if m.removedcredit_realization_lineages == nil { + m.removedcredit_realization_lineages = make(map[string]struct{}) + } + for i := range ids { + delete(m.credit_realization_lineages, ids[i]) + m.removedcredit_realization_lineages[ids[i]] = struct{}{} + } +} + +// RemovedCreditRealizationLineages returns the removed IDs of the "credit_realization_lineages" edge to the CreditRealizationLineage entity. +func (m *ChargeMutation) RemovedCreditRealizationLineagesIDs() (ids []string) { + for id := range m.removedcredit_realization_lineages { + ids = append(ids, id) + } + return +} + +// CreditRealizationLineagesIDs returns the "credit_realization_lineages" edge IDs in the mutation. +func (m *ChargeMutation) CreditRealizationLineagesIDs() (ids []string) { + for id := range m.credit_realization_lineages { + ids = append(ids, id) + } + return +} + +// ResetCreditRealizationLineages resets all changes to the "credit_realization_lineages" edge. +func (m *ChargeMutation) ResetCreditRealizationLineages() { + m.credit_realization_lineages = nil + m.clearedcredit_realization_lineages = false + m.removedcredit_realization_lineages = nil +} + // Where appends a list predicates to the ChargeMutation builder. func (m *ChargeMutation) Where(ps ...predicate.Charge) { m.predicates = append(m.predicates, ps...) @@ -36372,7 +36433,7 @@ func (m *ChargeMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ChargeMutation) AddedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.flat_fee != nil { edges = append(edges, charge.EdgeFlatFee) } @@ -36388,6 +36449,9 @@ func (m *ChargeMutation) AddedEdges() []string { if m.billing_split_line_groups != nil { edges = append(edges, charge.EdgeBillingSplitLineGroups) } + if m.credit_realization_lineages != nil { + edges = append(edges, charge.EdgeCreditRealizationLineages) + } return edges } @@ -36419,19 +36483,28 @@ func (m *ChargeMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case charge.EdgeCreditRealizationLineages: + ids := make([]ent.Value, 0, len(m.credit_realization_lineages)) + for id := range m.credit_realization_lineages { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ChargeMutation) RemovedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.removedbilling_invoice_lines != nil { edges = append(edges, charge.EdgeBillingInvoiceLines) } if m.removedbilling_split_line_groups != nil { edges = append(edges, charge.EdgeBillingSplitLineGroups) } + if m.removedcredit_realization_lineages != nil { + edges = append(edges, charge.EdgeCreditRealizationLineages) + } return edges } @@ -36451,13 +36524,19 @@ func (m *ChargeMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case charge.EdgeCreditRealizationLineages: + ids := make([]ent.Value, 0, len(m.removedcredit_realization_lineages)) + for id := range m.removedcredit_realization_lineages { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ChargeMutation) ClearedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.clearedflat_fee { edges = append(edges, charge.EdgeFlatFee) } @@ -36473,6 +36552,9 @@ func (m *ChargeMutation) ClearedEdges() []string { if m.clearedbilling_split_line_groups { edges = append(edges, charge.EdgeBillingSplitLineGroups) } + if m.clearedcredit_realization_lineages { + edges = append(edges, charge.EdgeCreditRealizationLineages) + } return edges } @@ -36490,6 +36572,8 @@ func (m *ChargeMutation) EdgeCleared(name string) bool { return m.clearedbilling_invoice_lines case charge.EdgeBillingSplitLineGroups: return m.clearedbilling_split_line_groups + case charge.EdgeCreditRealizationLineages: + return m.clearedcredit_realization_lineages } return false } @@ -36530,6 +36614,9 @@ func (m *ChargeMutation) ResetEdge(name string) error { case charge.EdgeBillingSplitLineGroups: m.ResetBillingSplitLineGroups() return nil + case charge.EdgeCreditRealizationLineages: + m.ResetCreditRealizationLineages() + return nil } return fmt.Errorf("unknown Charge edge %s", name) } @@ -56977,6 +57064,1498 @@ func (m *ChargeUsageBasedRunsMutation) ResetEdge(name string) error { return fmt.Errorf("unknown ChargeUsageBasedRuns edge %s", name) } +// CreditRealizationLineageMutation represents an operation that mutates the CreditRealizationLineage nodes in the graph. +type CreditRealizationLineageMutation struct { + config + op Op + typ string + id *string + namespace *string + root_realization_id *string + customer_id *string + currency *currencyx.Code + origin_kind *creditrealization.LineageOriginKind + created_at *time.Time + clearedFields map[string]struct{} + charge *string + clearedcharge bool + segments map[string]struct{} + removedsegments map[string]struct{} + clearedsegments bool + done bool + oldValue func(context.Context) (*CreditRealizationLineage, error) + predicates []predicate.CreditRealizationLineage +} + +var _ ent.Mutation = (*CreditRealizationLineageMutation)(nil) + +// creditrealizationlineageOption allows management of the mutation configuration using functional options. +type creditrealizationlineageOption func(*CreditRealizationLineageMutation) + +// newCreditRealizationLineageMutation creates new mutation for the CreditRealizationLineage entity. +func newCreditRealizationLineageMutation(c config, op Op, opts ...creditrealizationlineageOption) *CreditRealizationLineageMutation { + m := &CreditRealizationLineageMutation{ + config: c, + op: op, + typ: TypeCreditRealizationLineage, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withCreditRealizationLineageID sets the ID field of the mutation. +func withCreditRealizationLineageID(id string) creditrealizationlineageOption { + return func(m *CreditRealizationLineageMutation) { + var ( + err error + once sync.Once + value *CreditRealizationLineage + ) + m.oldValue = func(ctx context.Context) (*CreditRealizationLineage, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().CreditRealizationLineage.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withCreditRealizationLineage sets the old CreditRealizationLineage of the mutation. +func withCreditRealizationLineage(node *CreditRealizationLineage) creditrealizationlineageOption { + return func(m *CreditRealizationLineageMutation) { + m.oldValue = func(context.Context) (*CreditRealizationLineage, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m CreditRealizationLineageMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m CreditRealizationLineageMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("db: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of CreditRealizationLineage entities. +func (m *CreditRealizationLineageMutation) SetID(id string) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *CreditRealizationLineageMutation) ID() (id string, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *CreditRealizationLineageMutation) IDs(ctx context.Context) ([]string, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []string{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().CreditRealizationLineage.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetNamespace sets the "namespace" field. +func (m *CreditRealizationLineageMutation) SetNamespace(s string) { + m.namespace = &s +} + +// Namespace returns the value of the "namespace" field in the mutation. +func (m *CreditRealizationLineageMutation) Namespace() (r string, exists bool) { + v := m.namespace + if v == nil { + return + } + return *v, true +} + +// OldNamespace returns the old "namespace" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldNamespace(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldNamespace is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldNamespace requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldNamespace: %w", err) + } + return oldValue.Namespace, nil +} + +// ResetNamespace resets all changes to the "namespace" field. +func (m *CreditRealizationLineageMutation) ResetNamespace() { + m.namespace = nil +} + +// SetChargeID sets the "charge_id" field. +func (m *CreditRealizationLineageMutation) SetChargeID(s string) { + m.charge = &s +} + +// ChargeID returns the value of the "charge_id" field in the mutation. +func (m *CreditRealizationLineageMutation) ChargeID() (r string, exists bool) { + v := m.charge + if v == nil { + return + } + return *v, true +} + +// OldChargeID returns the old "charge_id" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldChargeID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChargeID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChargeID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChargeID: %w", err) + } + return oldValue.ChargeID, nil +} + +// ResetChargeID resets all changes to the "charge_id" field. +func (m *CreditRealizationLineageMutation) ResetChargeID() { + m.charge = nil +} + +// SetRootRealizationID sets the "root_realization_id" field. +func (m *CreditRealizationLineageMutation) SetRootRealizationID(s string) { + m.root_realization_id = &s +} + +// RootRealizationID returns the value of the "root_realization_id" field in the mutation. +func (m *CreditRealizationLineageMutation) RootRealizationID() (r string, exists bool) { + v := m.root_realization_id + if v == nil { + return + } + return *v, true +} + +// OldRootRealizationID returns the old "root_realization_id" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldRootRealizationID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRootRealizationID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRootRealizationID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRootRealizationID: %w", err) + } + return oldValue.RootRealizationID, nil +} + +// ResetRootRealizationID resets all changes to the "root_realization_id" field. +func (m *CreditRealizationLineageMutation) ResetRootRealizationID() { + m.root_realization_id = nil +} + +// SetCustomerID sets the "customer_id" field. +func (m *CreditRealizationLineageMutation) SetCustomerID(s string) { + m.customer_id = &s +} + +// CustomerID returns the value of the "customer_id" field in the mutation. +func (m *CreditRealizationLineageMutation) CustomerID() (r string, exists bool) { + v := m.customer_id + if v == nil { + return + } + return *v, true +} + +// OldCustomerID returns the old "customer_id" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldCustomerID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCustomerID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCustomerID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCustomerID: %w", err) + } + return oldValue.CustomerID, nil +} + +// ResetCustomerID resets all changes to the "customer_id" field. +func (m *CreditRealizationLineageMutation) ResetCustomerID() { + m.customer_id = nil +} + +// SetCurrency sets the "currency" field. +func (m *CreditRealizationLineageMutation) SetCurrency(c currencyx.Code) { + m.currency = &c +} + +// Currency returns the value of the "currency" field in the mutation. +func (m *CreditRealizationLineageMutation) Currency() (r currencyx.Code, exists bool) { + v := m.currency + if v == nil { + return + } + return *v, true +} + +// OldCurrency returns the old "currency" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldCurrency(ctx context.Context) (v currencyx.Code, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCurrency is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCurrency requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCurrency: %w", err) + } + return oldValue.Currency, nil +} + +// ResetCurrency resets all changes to the "currency" field. +func (m *CreditRealizationLineageMutation) ResetCurrency() { + m.currency = nil +} + +// SetOriginKind sets the "origin_kind" field. +func (m *CreditRealizationLineageMutation) SetOriginKind(cok creditrealization.LineageOriginKind) { + m.origin_kind = &cok +} + +// OriginKind returns the value of the "origin_kind" field in the mutation. +func (m *CreditRealizationLineageMutation) OriginKind() (r creditrealization.LineageOriginKind, exists bool) { + v := m.origin_kind + if v == nil { + return + } + return *v, true +} + +// OldOriginKind returns the old "origin_kind" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldOriginKind(ctx context.Context) (v creditrealization.LineageOriginKind, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOriginKind is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOriginKind requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOriginKind: %w", err) + } + return oldValue.OriginKind, nil +} + +// ResetOriginKind resets all changes to the "origin_kind" field. +func (m *CreditRealizationLineageMutation) ResetOriginKind() { + m.origin_kind = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *CreditRealizationLineageMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *CreditRealizationLineageMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the CreditRealizationLineage entity. +// If the CreditRealizationLineage object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *CreditRealizationLineageMutation) ResetCreatedAt() { + m.created_at = nil +} + +// ClearCharge clears the "charge" edge to the Charge entity. +func (m *CreditRealizationLineageMutation) ClearCharge() { + m.clearedcharge = true + m.clearedFields[creditrealizationlineage.FieldChargeID] = struct{}{} +} + +// ChargeCleared reports if the "charge" edge to the Charge entity was cleared. +func (m *CreditRealizationLineageMutation) ChargeCleared() bool { + return m.clearedcharge +} + +// ChargeIDs returns the "charge" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ChargeID instead. It exists only for internal usage by the builders. +func (m *CreditRealizationLineageMutation) ChargeIDs() (ids []string) { + if id := m.charge; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetCharge resets all changes to the "charge" edge. +func (m *CreditRealizationLineageMutation) ResetCharge() { + m.charge = nil + m.clearedcharge = false +} + +// AddSegmentIDs adds the "segments" edge to the CreditRealizationLineageSegment entity by ids. +func (m *CreditRealizationLineageMutation) AddSegmentIDs(ids ...string) { + if m.segments == nil { + m.segments = make(map[string]struct{}) + } + for i := range ids { + m.segments[ids[i]] = struct{}{} + } +} + +// ClearSegments clears the "segments" edge to the CreditRealizationLineageSegment entity. +func (m *CreditRealizationLineageMutation) ClearSegments() { + m.clearedsegments = true +} + +// SegmentsCleared reports if the "segments" edge to the CreditRealizationLineageSegment entity was cleared. +func (m *CreditRealizationLineageMutation) SegmentsCleared() bool { + return m.clearedsegments +} + +// RemoveSegmentIDs removes the "segments" edge to the CreditRealizationLineageSegment entity by IDs. +func (m *CreditRealizationLineageMutation) RemoveSegmentIDs(ids ...string) { + if m.removedsegments == nil { + m.removedsegments = make(map[string]struct{}) + } + for i := range ids { + delete(m.segments, ids[i]) + m.removedsegments[ids[i]] = struct{}{} + } +} + +// RemovedSegments returns the removed IDs of the "segments" edge to the CreditRealizationLineageSegment entity. +func (m *CreditRealizationLineageMutation) RemovedSegmentsIDs() (ids []string) { + for id := range m.removedsegments { + ids = append(ids, id) + } + return +} + +// SegmentsIDs returns the "segments" edge IDs in the mutation. +func (m *CreditRealizationLineageMutation) SegmentsIDs() (ids []string) { + for id := range m.segments { + ids = append(ids, id) + } + return +} + +// ResetSegments resets all changes to the "segments" edge. +func (m *CreditRealizationLineageMutation) ResetSegments() { + m.segments = nil + m.clearedsegments = false + m.removedsegments = nil +} + +// Where appends a list predicates to the CreditRealizationLineageMutation builder. +func (m *CreditRealizationLineageMutation) Where(ps ...predicate.CreditRealizationLineage) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the CreditRealizationLineageMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *CreditRealizationLineageMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.CreditRealizationLineage, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *CreditRealizationLineageMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *CreditRealizationLineageMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (CreditRealizationLineage). +func (m *CreditRealizationLineageMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *CreditRealizationLineageMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.namespace != nil { + fields = append(fields, creditrealizationlineage.FieldNamespace) + } + if m.charge != nil { + fields = append(fields, creditrealizationlineage.FieldChargeID) + } + if m.root_realization_id != nil { + fields = append(fields, creditrealizationlineage.FieldRootRealizationID) + } + if m.customer_id != nil { + fields = append(fields, creditrealizationlineage.FieldCustomerID) + } + if m.currency != nil { + fields = append(fields, creditrealizationlineage.FieldCurrency) + } + if m.origin_kind != nil { + fields = append(fields, creditrealizationlineage.FieldOriginKind) + } + if m.created_at != nil { + fields = append(fields, creditrealizationlineage.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *CreditRealizationLineageMutation) Field(name string) (ent.Value, bool) { + switch name { + case creditrealizationlineage.FieldNamespace: + return m.Namespace() + case creditrealizationlineage.FieldChargeID: + return m.ChargeID() + case creditrealizationlineage.FieldRootRealizationID: + return m.RootRealizationID() + case creditrealizationlineage.FieldCustomerID: + return m.CustomerID() + case creditrealizationlineage.FieldCurrency: + return m.Currency() + case creditrealizationlineage.FieldOriginKind: + return m.OriginKind() + case creditrealizationlineage.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *CreditRealizationLineageMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case creditrealizationlineage.FieldNamespace: + return m.OldNamespace(ctx) + case creditrealizationlineage.FieldChargeID: + return m.OldChargeID(ctx) + case creditrealizationlineage.FieldRootRealizationID: + return m.OldRootRealizationID(ctx) + case creditrealizationlineage.FieldCustomerID: + return m.OldCustomerID(ctx) + case creditrealizationlineage.FieldCurrency: + return m.OldCurrency(ctx) + case creditrealizationlineage.FieldOriginKind: + return m.OldOriginKind(ctx) + case creditrealizationlineage.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown CreditRealizationLineage field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CreditRealizationLineageMutation) SetField(name string, value ent.Value) error { + switch name { + case creditrealizationlineage.FieldNamespace: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetNamespace(v) + return nil + case creditrealizationlineage.FieldChargeID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChargeID(v) + return nil + case creditrealizationlineage.FieldRootRealizationID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRootRealizationID(v) + return nil + case creditrealizationlineage.FieldCustomerID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomerID(v) + return nil + case creditrealizationlineage.FieldCurrency: + v, ok := value.(currencyx.Code) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCurrency(v) + return nil + case creditrealizationlineage.FieldOriginKind: + v, ok := value.(creditrealization.LineageOriginKind) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOriginKind(v) + return nil + case creditrealizationlineage.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown CreditRealizationLineage field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *CreditRealizationLineageMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *CreditRealizationLineageMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CreditRealizationLineageMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown CreditRealizationLineage numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *CreditRealizationLineageMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *CreditRealizationLineageMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *CreditRealizationLineageMutation) ClearField(name string) error { + return fmt.Errorf("unknown CreditRealizationLineage nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *CreditRealizationLineageMutation) ResetField(name string) error { + switch name { + case creditrealizationlineage.FieldNamespace: + m.ResetNamespace() + return nil + case creditrealizationlineage.FieldChargeID: + m.ResetChargeID() + return nil + case creditrealizationlineage.FieldRootRealizationID: + m.ResetRootRealizationID() + return nil + case creditrealizationlineage.FieldCustomerID: + m.ResetCustomerID() + return nil + case creditrealizationlineage.FieldCurrency: + m.ResetCurrency() + return nil + case creditrealizationlineage.FieldOriginKind: + m.ResetOriginKind() + return nil + case creditrealizationlineage.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineage field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *CreditRealizationLineageMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.charge != nil { + edges = append(edges, creditrealizationlineage.EdgeCharge) + } + if m.segments != nil { + edges = append(edges, creditrealizationlineage.EdgeSegments) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *CreditRealizationLineageMutation) AddedIDs(name string) []ent.Value { + switch name { + case creditrealizationlineage.EdgeCharge: + if id := m.charge; id != nil { + return []ent.Value{*id} + } + case creditrealizationlineage.EdgeSegments: + ids := make([]ent.Value, 0, len(m.segments)) + for id := range m.segments { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *CreditRealizationLineageMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + if m.removedsegments != nil { + edges = append(edges, creditrealizationlineage.EdgeSegments) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *CreditRealizationLineageMutation) RemovedIDs(name string) []ent.Value { + switch name { + case creditrealizationlineage.EdgeSegments: + ids := make([]ent.Value, 0, len(m.removedsegments)) + for id := range m.removedsegments { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *CreditRealizationLineageMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedcharge { + edges = append(edges, creditrealizationlineage.EdgeCharge) + } + if m.clearedsegments { + edges = append(edges, creditrealizationlineage.EdgeSegments) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *CreditRealizationLineageMutation) EdgeCleared(name string) bool { + switch name { + case creditrealizationlineage.EdgeCharge: + return m.clearedcharge + case creditrealizationlineage.EdgeSegments: + return m.clearedsegments + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *CreditRealizationLineageMutation) ClearEdge(name string) error { + switch name { + case creditrealizationlineage.EdgeCharge: + m.ClearCharge() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineage unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *CreditRealizationLineageMutation) ResetEdge(name string) error { + switch name { + case creditrealizationlineage.EdgeCharge: + m.ResetCharge() + return nil + case creditrealizationlineage.EdgeSegments: + m.ResetSegments() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineage edge %s", name) +} + +// CreditRealizationLineageSegmentMutation represents an operation that mutates the CreditRealizationLineageSegment nodes in the graph. +type CreditRealizationLineageSegmentMutation struct { + config + op Op + typ string + id *string + amount *alpacadecimal.Decimal + state *creditrealization.LineageSegmentState + backing_transaction_group_id *string + closed_at *time.Time + created_at *time.Time + clearedFields map[string]struct{} + lineage *string + clearedlineage bool + done bool + oldValue func(context.Context) (*CreditRealizationLineageSegment, error) + predicates []predicate.CreditRealizationLineageSegment +} + +var _ ent.Mutation = (*CreditRealizationLineageSegmentMutation)(nil) + +// creditrealizationlineagesegmentOption allows management of the mutation configuration using functional options. +type creditrealizationlineagesegmentOption func(*CreditRealizationLineageSegmentMutation) + +// newCreditRealizationLineageSegmentMutation creates new mutation for the CreditRealizationLineageSegment entity. +func newCreditRealizationLineageSegmentMutation(c config, op Op, opts ...creditrealizationlineagesegmentOption) *CreditRealizationLineageSegmentMutation { + m := &CreditRealizationLineageSegmentMutation{ + config: c, + op: op, + typ: TypeCreditRealizationLineageSegment, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withCreditRealizationLineageSegmentID sets the ID field of the mutation. +func withCreditRealizationLineageSegmentID(id string) creditrealizationlineagesegmentOption { + return func(m *CreditRealizationLineageSegmentMutation) { + var ( + err error + once sync.Once + value *CreditRealizationLineageSegment + ) + m.oldValue = func(ctx context.Context) (*CreditRealizationLineageSegment, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().CreditRealizationLineageSegment.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withCreditRealizationLineageSegment sets the old CreditRealizationLineageSegment of the mutation. +func withCreditRealizationLineageSegment(node *CreditRealizationLineageSegment) creditrealizationlineagesegmentOption { + return func(m *CreditRealizationLineageSegmentMutation) { + m.oldValue = func(context.Context) (*CreditRealizationLineageSegment, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m CreditRealizationLineageSegmentMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m CreditRealizationLineageSegmentMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("db: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of CreditRealizationLineageSegment entities. +func (m *CreditRealizationLineageSegmentMutation) SetID(id string) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *CreditRealizationLineageSegmentMutation) ID() (id string, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *CreditRealizationLineageSegmentMutation) IDs(ctx context.Context) ([]string, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []string{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().CreditRealizationLineageSegment.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetLineageID sets the "lineage_id" field. +func (m *CreditRealizationLineageSegmentMutation) SetLineageID(s string) { + m.lineage = &s +} + +// LineageID returns the value of the "lineage_id" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) LineageID() (r string, exists bool) { + v := m.lineage + if v == nil { + return + } + return *v, true +} + +// OldLineageID returns the old "lineage_id" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldLineageID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLineageID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLineageID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLineageID: %w", err) + } + return oldValue.LineageID, nil +} + +// ResetLineageID resets all changes to the "lineage_id" field. +func (m *CreditRealizationLineageSegmentMutation) ResetLineageID() { + m.lineage = nil +} + +// SetAmount sets the "amount" field. +func (m *CreditRealizationLineageSegmentMutation) SetAmount(a alpacadecimal.Decimal) { + m.amount = &a +} + +// Amount returns the value of the "amount" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) Amount() (r alpacadecimal.Decimal, exists bool) { + v := m.amount + if v == nil { + return + } + return *v, true +} + +// OldAmount returns the old "amount" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldAmount(ctx context.Context) (v alpacadecimal.Decimal, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAmount: %w", err) + } + return oldValue.Amount, nil +} + +// ResetAmount resets all changes to the "amount" field. +func (m *CreditRealizationLineageSegmentMutation) ResetAmount() { + m.amount = nil +} + +// SetState sets the "state" field. +func (m *CreditRealizationLineageSegmentMutation) SetState(css creditrealization.LineageSegmentState) { + m.state = &css +} + +// State returns the value of the "state" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) State() (r creditrealization.LineageSegmentState, exists bool) { + v := m.state + if v == nil { + return + } + return *v, true +} + +// OldState returns the old "state" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldState(ctx context.Context) (v creditrealization.LineageSegmentState, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldState: %w", err) + } + return oldValue.State, nil +} + +// ResetState resets all changes to the "state" field. +func (m *CreditRealizationLineageSegmentMutation) ResetState() { + m.state = nil +} + +// SetBackingTransactionGroupID sets the "backing_transaction_group_id" field. +func (m *CreditRealizationLineageSegmentMutation) SetBackingTransactionGroupID(s string) { + m.backing_transaction_group_id = &s +} + +// BackingTransactionGroupID returns the value of the "backing_transaction_group_id" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) BackingTransactionGroupID() (r string, exists bool) { + v := m.backing_transaction_group_id + if v == nil { + return + } + return *v, true +} + +// OldBackingTransactionGroupID returns the old "backing_transaction_group_id" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldBackingTransactionGroupID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBackingTransactionGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBackingTransactionGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBackingTransactionGroupID: %w", err) + } + return oldValue.BackingTransactionGroupID, nil +} + +// ClearBackingTransactionGroupID clears the value of the "backing_transaction_group_id" field. +func (m *CreditRealizationLineageSegmentMutation) ClearBackingTransactionGroupID() { + m.backing_transaction_group_id = nil + m.clearedFields[creditrealizationlineagesegment.FieldBackingTransactionGroupID] = struct{}{} +} + +// BackingTransactionGroupIDCleared returns if the "backing_transaction_group_id" field was cleared in this mutation. +func (m *CreditRealizationLineageSegmentMutation) BackingTransactionGroupIDCleared() bool { + _, ok := m.clearedFields[creditrealizationlineagesegment.FieldBackingTransactionGroupID] + return ok +} + +// ResetBackingTransactionGroupID resets all changes to the "backing_transaction_group_id" field. +func (m *CreditRealizationLineageSegmentMutation) ResetBackingTransactionGroupID() { + m.backing_transaction_group_id = nil + delete(m.clearedFields, creditrealizationlineagesegment.FieldBackingTransactionGroupID) +} + +// SetClosedAt sets the "closed_at" field. +func (m *CreditRealizationLineageSegmentMutation) SetClosedAt(t time.Time) { + m.closed_at = &t +} + +// ClosedAt returns the value of the "closed_at" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) ClosedAt() (r time.Time, exists bool) { + v := m.closed_at + if v == nil { + return + } + return *v, true +} + +// OldClosedAt returns the old "closed_at" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldClosedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClosedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClosedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClosedAt: %w", err) + } + return oldValue.ClosedAt, nil +} + +// ClearClosedAt clears the value of the "closed_at" field. +func (m *CreditRealizationLineageSegmentMutation) ClearClosedAt() { + m.closed_at = nil + m.clearedFields[creditrealizationlineagesegment.FieldClosedAt] = struct{}{} +} + +// ClosedAtCleared returns if the "closed_at" field was cleared in this mutation. +func (m *CreditRealizationLineageSegmentMutation) ClosedAtCleared() bool { + _, ok := m.clearedFields[creditrealizationlineagesegment.FieldClosedAt] + return ok +} + +// ResetClosedAt resets all changes to the "closed_at" field. +func (m *CreditRealizationLineageSegmentMutation) ResetClosedAt() { + m.closed_at = nil + delete(m.clearedFields, creditrealizationlineagesegment.FieldClosedAt) +} + +// SetCreatedAt sets the "created_at" field. +func (m *CreditRealizationLineageSegmentMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *CreditRealizationLineageSegmentMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the CreditRealizationLineageSegment entity. +// If the CreditRealizationLineageSegment object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CreditRealizationLineageSegmentMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *CreditRealizationLineageSegmentMutation) ResetCreatedAt() { + m.created_at = nil +} + +// ClearLineage clears the "lineage" edge to the CreditRealizationLineage entity. +func (m *CreditRealizationLineageSegmentMutation) ClearLineage() { + m.clearedlineage = true + m.clearedFields[creditrealizationlineagesegment.FieldLineageID] = struct{}{} +} + +// LineageCleared reports if the "lineage" edge to the CreditRealizationLineage entity was cleared. +func (m *CreditRealizationLineageSegmentMutation) LineageCleared() bool { + return m.clearedlineage +} + +// LineageIDs returns the "lineage" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// LineageID instead. It exists only for internal usage by the builders. +func (m *CreditRealizationLineageSegmentMutation) LineageIDs() (ids []string) { + if id := m.lineage; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetLineage resets all changes to the "lineage" edge. +func (m *CreditRealizationLineageSegmentMutation) ResetLineage() { + m.lineage = nil + m.clearedlineage = false +} + +// Where appends a list predicates to the CreditRealizationLineageSegmentMutation builder. +func (m *CreditRealizationLineageSegmentMutation) Where(ps ...predicate.CreditRealizationLineageSegment) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the CreditRealizationLineageSegmentMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *CreditRealizationLineageSegmentMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.CreditRealizationLineageSegment, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *CreditRealizationLineageSegmentMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *CreditRealizationLineageSegmentMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (CreditRealizationLineageSegment). +func (m *CreditRealizationLineageSegmentMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *CreditRealizationLineageSegmentMutation) Fields() []string { + fields := make([]string, 0, 6) + if m.lineage != nil { + fields = append(fields, creditrealizationlineagesegment.FieldLineageID) + } + if m.amount != nil { + fields = append(fields, creditrealizationlineagesegment.FieldAmount) + } + if m.state != nil { + fields = append(fields, creditrealizationlineagesegment.FieldState) + } + if m.backing_transaction_group_id != nil { + fields = append(fields, creditrealizationlineagesegment.FieldBackingTransactionGroupID) + } + if m.closed_at != nil { + fields = append(fields, creditrealizationlineagesegment.FieldClosedAt) + } + if m.created_at != nil { + fields = append(fields, creditrealizationlineagesegment.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *CreditRealizationLineageSegmentMutation) Field(name string) (ent.Value, bool) { + switch name { + case creditrealizationlineagesegment.FieldLineageID: + return m.LineageID() + case creditrealizationlineagesegment.FieldAmount: + return m.Amount() + case creditrealizationlineagesegment.FieldState: + return m.State() + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + return m.BackingTransactionGroupID() + case creditrealizationlineagesegment.FieldClosedAt: + return m.ClosedAt() + case creditrealizationlineagesegment.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *CreditRealizationLineageSegmentMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case creditrealizationlineagesegment.FieldLineageID: + return m.OldLineageID(ctx) + case creditrealizationlineagesegment.FieldAmount: + return m.OldAmount(ctx) + case creditrealizationlineagesegment.FieldState: + return m.OldState(ctx) + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + return m.OldBackingTransactionGroupID(ctx) + case creditrealizationlineagesegment.FieldClosedAt: + return m.OldClosedAt(ctx) + case creditrealizationlineagesegment.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown CreditRealizationLineageSegment field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CreditRealizationLineageSegmentMutation) SetField(name string, value ent.Value) error { + switch name { + case creditrealizationlineagesegment.FieldLineageID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLineageID(v) + return nil + case creditrealizationlineagesegment.FieldAmount: + v, ok := value.(alpacadecimal.Decimal) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAmount(v) + return nil + case creditrealizationlineagesegment.FieldState: + v, ok := value.(creditrealization.LineageSegmentState) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetState(v) + return nil + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBackingTransactionGroupID(v) + return nil + case creditrealizationlineagesegment.FieldClosedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClosedAt(v) + return nil + case creditrealizationlineagesegment.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown CreditRealizationLineageSegment field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *CreditRealizationLineageSegmentMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *CreditRealizationLineageSegmentMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CreditRealizationLineageSegmentMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown CreditRealizationLineageSegment numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *CreditRealizationLineageSegmentMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(creditrealizationlineagesegment.FieldBackingTransactionGroupID) { + fields = append(fields, creditrealizationlineagesegment.FieldBackingTransactionGroupID) + } + if m.FieldCleared(creditrealizationlineagesegment.FieldClosedAt) { + fields = append(fields, creditrealizationlineagesegment.FieldClosedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *CreditRealizationLineageSegmentMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *CreditRealizationLineageSegmentMutation) ClearField(name string) error { + switch name { + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + m.ClearBackingTransactionGroupID() + return nil + case creditrealizationlineagesegment.FieldClosedAt: + m.ClearClosedAt() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineageSegment nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *CreditRealizationLineageSegmentMutation) ResetField(name string) error { + switch name { + case creditrealizationlineagesegment.FieldLineageID: + m.ResetLineageID() + return nil + case creditrealizationlineagesegment.FieldAmount: + m.ResetAmount() + return nil + case creditrealizationlineagesegment.FieldState: + m.ResetState() + return nil + case creditrealizationlineagesegment.FieldBackingTransactionGroupID: + m.ResetBackingTransactionGroupID() + return nil + case creditrealizationlineagesegment.FieldClosedAt: + m.ResetClosedAt() + return nil + case creditrealizationlineagesegment.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineageSegment field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *CreditRealizationLineageSegmentMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.lineage != nil { + edges = append(edges, creditrealizationlineagesegment.EdgeLineage) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *CreditRealizationLineageSegmentMutation) AddedIDs(name string) []ent.Value { + switch name { + case creditrealizationlineagesegment.EdgeLineage: + if id := m.lineage; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *CreditRealizationLineageSegmentMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *CreditRealizationLineageSegmentMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *CreditRealizationLineageSegmentMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedlineage { + edges = append(edges, creditrealizationlineagesegment.EdgeLineage) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *CreditRealizationLineageSegmentMutation) EdgeCleared(name string) bool { + switch name { + case creditrealizationlineagesegment.EdgeLineage: + return m.clearedlineage + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *CreditRealizationLineageSegmentMutation) ClearEdge(name string) error { + switch name { + case creditrealizationlineagesegment.EdgeLineage: + m.ClearLineage() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineageSegment unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *CreditRealizationLineageSegmentMutation) ResetEdge(name string) error { + switch name { + case creditrealizationlineagesegment.EdgeLineage: + m.ResetLineage() + return nil + } + return fmt.Errorf("unknown CreditRealizationLineageSegment edge %s", name) +} + // CurrencyCostBasisMutation represents an operation that mutates the CurrencyCostBasis nodes in the graph. type CurrencyCostBasisMutation struct { config diff --git a/openmeter/ent/db/paginate.go b/openmeter/ent/db/paginate.go index afd710d1d6..753cf1dccd 100644 --- a/openmeter/ent/db/paginate.go +++ b/openmeter/ent/db/paginate.go @@ -2271,6 +2271,122 @@ func (_m *ChargesSearchV1Query) Paginate(ctx context.Context, page pagination.Pa // type check var _ pagination.Paginator[*ChargesSearchV1] = (*ChargesSearchV1Query)(nil) +// Paginate runs the query and returns a paginated response. +// If page is its 0 value then it will return all the items and populate the response page accordingly. +func (_m *CreditRealizationLineageQuery) Paginate(ctx context.Context, page pagination.Page) (pagination.Result[*CreditRealizationLineage], error) { + // Get the limit and offset + limit, offset := page.Limit(), page.Offset() + + // Unset previous pagination settings + zero := 0 + _m.ctx.Offset = &zero + _m.ctx.Limit = &zero + + // Create duplicate of the query to run for + countQuery := _m.Clone() + pagedQuery := _m + + // Unset select for count query + countQuery.ctx.Fields = []string{} + + // Unset ordering for count query + countQuery.order = nil + + pagedResponse := pagination.Result[*CreditRealizationLineage]{ + Page: page, + } + + // Get the total count + count, err := countQuery.Count(ctx) + if err != nil { + return pagedResponse, fmt.Errorf("failed to get count: %w", err) + } + pagedResponse.TotalCount = count + + // If there are no items, return the empty response early + if count == 0 { + // Items should be [] not null. + pagedResponse.Items = make([]*CreditRealizationLineage, 0) + return pagedResponse, nil + } + + // If page is its 0 value then return all the items + if page.IsZero() { + offset = 0 + limit = count + } + + // Set the limit and offset + pagedQuery.ctx.Limit = &limit + pagedQuery.ctx.Offset = &offset + + // Get the paged items + items, err := pagedQuery.All(ctx) + pagedResponse.Items = items + return pagedResponse, err +} + +// type check +var _ pagination.Paginator[*CreditRealizationLineage] = (*CreditRealizationLineageQuery)(nil) + +// Paginate runs the query and returns a paginated response. +// If page is its 0 value then it will return all the items and populate the response page accordingly. +func (_m *CreditRealizationLineageSegmentQuery) Paginate(ctx context.Context, page pagination.Page) (pagination.Result[*CreditRealizationLineageSegment], error) { + // Get the limit and offset + limit, offset := page.Limit(), page.Offset() + + // Unset previous pagination settings + zero := 0 + _m.ctx.Offset = &zero + _m.ctx.Limit = &zero + + // Create duplicate of the query to run for + countQuery := _m.Clone() + pagedQuery := _m + + // Unset select for count query + countQuery.ctx.Fields = []string{} + + // Unset ordering for count query + countQuery.order = nil + + pagedResponse := pagination.Result[*CreditRealizationLineageSegment]{ + Page: page, + } + + // Get the total count + count, err := countQuery.Count(ctx) + if err != nil { + return pagedResponse, fmt.Errorf("failed to get count: %w", err) + } + pagedResponse.TotalCount = count + + // If there are no items, return the empty response early + if count == 0 { + // Items should be [] not null. + pagedResponse.Items = make([]*CreditRealizationLineageSegment, 0) + return pagedResponse, nil + } + + // If page is its 0 value then return all the items + if page.IsZero() { + offset = 0 + limit = count + } + + // Set the limit and offset + pagedQuery.ctx.Limit = &limit + pagedQuery.ctx.Offset = &offset + + // Get the paged items + items, err := pagedQuery.All(ctx) + pagedResponse.Items = items + return pagedResponse, err +} + +// type check +var _ pagination.Paginator[*CreditRealizationLineageSegment] = (*CreditRealizationLineageSegmentQuery)(nil) + // Paginate runs the query and returns a paginated response. // If page is its 0 value then it will return all the items and populate the response page accordingly. func (_m *CurrencyCostBasisQuery) Paginate(ctx context.Context, page pagination.Page) (pagination.Result[*CurrencyCostBasis], error) { diff --git a/openmeter/ent/db/predicate/predicate.go b/openmeter/ent/db/predicate/predicate.go index da7d473b85..ae99e8e887 100644 --- a/openmeter/ent/db/predicate/predicate.go +++ b/openmeter/ent/db/predicate/predicate.go @@ -255,6 +255,12 @@ type ChargeUsageBasedRuns func(*sql.Selector) // ChargesSearchV1 is the predicate function for chargessearchv1 builders. type ChargesSearchV1 func(*sql.Selector) +// CreditRealizationLineage is the predicate function for creditrealizationlineage builders. +type CreditRealizationLineage func(*sql.Selector) + +// CreditRealizationLineageSegment is the predicate function for creditrealizationlineagesegment builders. +type CreditRealizationLineageSegment func(*sql.Selector) + // CurrencyCostBasis is the predicate function for currencycostbasis builders. type CurrencyCostBasis func(*sql.Selector) diff --git a/openmeter/ent/db/runtime.go b/openmeter/ent/db/runtime.go index 25d1738e73..0d1ddf2b6d 100644 --- a/openmeter/ent/db/runtime.go +++ b/openmeter/ent/db/runtime.go @@ -47,6 +47,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruninvoicedusage" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedrunpayment" "github.com/openmeterio/openmeter/openmeter/ent/db/chargeusagebasedruns" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineage" + "github.com/openmeterio/openmeter/openmeter/ent/db/creditrealizationlineagesegment" "github.com/openmeterio/openmeter/openmeter/ent/db/currencycostbasis" "github.com/openmeterio/openmeter/openmeter/ent/db/customcurrency" "github.com/openmeterio/openmeter/openmeter/ent/db/customer" @@ -1373,6 +1375,62 @@ func init() { chargessearchv1.DefaultUpdatedAt = chargessearchv1DescUpdatedAt.Default.(func() time.Time) // chargessearchv1.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. chargessearchv1.UpdateDefaultUpdatedAt = chargessearchv1DescUpdatedAt.UpdateDefault.(func() time.Time) + creditrealizationlineageMixin := schema.CreditRealizationLineage{}.Mixin() + creditrealizationlineageMixinFields0 := creditrealizationlineageMixin[0].Fields() + _ = creditrealizationlineageMixinFields0 + creditrealizationlineageMixinFields1 := creditrealizationlineageMixin[1].Fields() + _ = creditrealizationlineageMixinFields1 + creditrealizationlineageFields := schema.CreditRealizationLineage{}.Fields() + _ = creditrealizationlineageFields + // creditrealizationlineageDescNamespace is the schema descriptor for namespace field. + creditrealizationlineageDescNamespace := creditrealizationlineageMixinFields1[0].Descriptor() + // creditrealizationlineage.NamespaceValidator is a validator for the "namespace" field. It is called by the builders before save. + creditrealizationlineage.NamespaceValidator = creditrealizationlineageDescNamespace.Validators[0].(func(string) error) + // creditrealizationlineageDescChargeID is the schema descriptor for charge_id field. + creditrealizationlineageDescChargeID := creditrealizationlineageFields[0].Descriptor() + // creditrealizationlineage.ChargeIDValidator is a validator for the "charge_id" field. It is called by the builders before save. + creditrealizationlineage.ChargeIDValidator = creditrealizationlineageDescChargeID.Validators[0].(func(string) error) + // creditrealizationlineageDescRootRealizationID is the schema descriptor for root_realization_id field. + creditrealizationlineageDescRootRealizationID := creditrealizationlineageFields[1].Descriptor() + // creditrealizationlineage.RootRealizationIDValidator is a validator for the "root_realization_id" field. It is called by the builders before save. + creditrealizationlineage.RootRealizationIDValidator = creditrealizationlineageDescRootRealizationID.Validators[0].(func(string) error) + // creditrealizationlineageDescCustomerID is the schema descriptor for customer_id field. + creditrealizationlineageDescCustomerID := creditrealizationlineageFields[2].Descriptor() + // creditrealizationlineage.CustomerIDValidator is a validator for the "customer_id" field. It is called by the builders before save. + creditrealizationlineage.CustomerIDValidator = creditrealizationlineageDescCustomerID.Validators[0].(func(string) error) + // creditrealizationlineageDescCurrency is the schema descriptor for currency field. + creditrealizationlineageDescCurrency := creditrealizationlineageFields[3].Descriptor() + // creditrealizationlineage.CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. + creditrealizationlineage.CurrencyValidator = creditrealizationlineageDescCurrency.Validators[0].(func(string) error) + // creditrealizationlineageDescCreatedAt is the schema descriptor for created_at field. + creditrealizationlineageDescCreatedAt := creditrealizationlineageFields[5].Descriptor() + // creditrealizationlineage.DefaultCreatedAt holds the default value on creation for the created_at field. + creditrealizationlineage.DefaultCreatedAt = creditrealizationlineageDescCreatedAt.Default.(func() time.Time) + // creditrealizationlineageDescID is the schema descriptor for id field. + creditrealizationlineageDescID := creditrealizationlineageMixinFields0[0].Descriptor() + // creditrealizationlineage.DefaultID holds the default value on creation for the id field. + creditrealizationlineage.DefaultID = creditrealizationlineageDescID.Default.(func() string) + creditrealizationlineagesegmentMixin := schema.CreditRealizationLineageSegment{}.Mixin() + creditrealizationlineagesegmentMixinFields0 := creditrealizationlineagesegmentMixin[0].Fields() + _ = creditrealizationlineagesegmentMixinFields0 + creditrealizationlineagesegmentFields := schema.CreditRealizationLineageSegment{}.Fields() + _ = creditrealizationlineagesegmentFields + // creditrealizationlineagesegmentDescLineageID is the schema descriptor for lineage_id field. + creditrealizationlineagesegmentDescLineageID := creditrealizationlineagesegmentFields[0].Descriptor() + // creditrealizationlineagesegment.LineageIDValidator is a validator for the "lineage_id" field. It is called by the builders before save. + creditrealizationlineagesegment.LineageIDValidator = creditrealizationlineagesegmentDescLineageID.Validators[0].(func(string) error) + // creditrealizationlineagesegmentDescBackingTransactionGroupID is the schema descriptor for backing_transaction_group_id field. + creditrealizationlineagesegmentDescBackingTransactionGroupID := creditrealizationlineagesegmentFields[3].Descriptor() + // creditrealizationlineagesegment.BackingTransactionGroupIDValidator is a validator for the "backing_transaction_group_id" field. It is called by the builders before save. + creditrealizationlineagesegment.BackingTransactionGroupIDValidator = creditrealizationlineagesegmentDescBackingTransactionGroupID.Validators[0].(func(string) error) + // creditrealizationlineagesegmentDescCreatedAt is the schema descriptor for created_at field. + creditrealizationlineagesegmentDescCreatedAt := creditrealizationlineagesegmentFields[5].Descriptor() + // creditrealizationlineagesegment.DefaultCreatedAt holds the default value on creation for the created_at field. + creditrealizationlineagesegment.DefaultCreatedAt = creditrealizationlineagesegmentDescCreatedAt.Default.(func() time.Time) + // creditrealizationlineagesegmentDescID is the schema descriptor for id field. + creditrealizationlineagesegmentDescID := creditrealizationlineagesegmentMixinFields0[0].Descriptor() + // creditrealizationlineagesegment.DefaultID holds the default value on creation for the id field. + creditrealizationlineagesegment.DefaultID = creditrealizationlineagesegmentDescID.Default.(func() string) currencycostbasisMixin := schema.CurrencyCostBasis{}.Mixin() currencycostbasisMixinFields0 := currencycostbasisMixin[0].Fields() _ = currencycostbasisMixinFields0 diff --git a/openmeter/ent/db/setorclear.go b/openmeter/ent/db/setorclear.go index 01ef8dc2be..c781249b2d 100644 --- a/openmeter/ent/db/setorclear.go +++ b/openmeter/ent/db/setorclear.go @@ -3181,6 +3181,34 @@ func (u *ChargeUsageBasedRunsUpdateOne) SetOrClearDeletedAt(value *time.Time) *C return u.SetDeletedAt(*value) } +func (u *CreditRealizationLineageSegmentUpdate) SetOrClearBackingTransactionGroupID(value *string) *CreditRealizationLineageSegmentUpdate { + if value == nil { + return u.ClearBackingTransactionGroupID() + } + return u.SetBackingTransactionGroupID(*value) +} + +func (u *CreditRealizationLineageSegmentUpdateOne) SetOrClearBackingTransactionGroupID(value *string) *CreditRealizationLineageSegmentUpdateOne { + if value == nil { + return u.ClearBackingTransactionGroupID() + } + return u.SetBackingTransactionGroupID(*value) +} + +func (u *CreditRealizationLineageSegmentUpdate) SetOrClearClosedAt(value *time.Time) *CreditRealizationLineageSegmentUpdate { + if value == nil { + return u.ClearClosedAt() + } + return u.SetClosedAt(*value) +} + +func (u *CreditRealizationLineageSegmentUpdateOne) SetOrClearClosedAt(value *time.Time) *CreditRealizationLineageSegmentUpdateOne { + if value == nil { + return u.ClearClosedAt() + } + return u.SetClosedAt(*value) +} + func (u *CurrencyCostBasisUpdate) SetOrClearDeletedAt(value *time.Time) *CurrencyCostBasisUpdate { if value == nil { return u.ClearDeletedAt() diff --git a/openmeter/ent/db/tx.go b/openmeter/ent/db/tx.go index 73fe73006b..02f2c8ca60 100644 --- a/openmeter/ent/db/tx.go +++ b/openmeter/ent/db/tx.go @@ -92,6 +92,10 @@ type Tx struct { ChargeUsageBasedRuns *ChargeUsageBasedRunsClient // ChargesSearchV1 is the client for interacting with the ChargesSearchV1 builders. ChargesSearchV1 *ChargesSearchV1Client + // CreditRealizationLineage is the client for interacting with the CreditRealizationLineage builders. + CreditRealizationLineage *CreditRealizationLineageClient + // CreditRealizationLineageSegment is the client for interacting with the CreditRealizationLineageSegment builders. + CreditRealizationLineageSegment *CreditRealizationLineageSegmentClient // CurrencyCostBasis is the client for interacting with the CurrencyCostBasis builders. CurrencyCostBasis *CurrencyCostBasisClient // CustomCurrency is the client for interacting with the CustomCurrency builders. @@ -328,6 +332,8 @@ func (tx *Tx) init() { tx.ChargeUsageBasedRunPayment = NewChargeUsageBasedRunPaymentClient(tx.config) tx.ChargeUsageBasedRuns = NewChargeUsageBasedRunsClient(tx.config) tx.ChargesSearchV1 = NewChargesSearchV1Client(tx.config) + tx.CreditRealizationLineage = NewCreditRealizationLineageClient(tx.config) + tx.CreditRealizationLineageSegment = NewCreditRealizationLineageSegmentClient(tx.config) tx.CurrencyCostBasis = NewCurrencyCostBasisClient(tx.config) tx.CustomCurrency = NewCustomCurrencyClient(tx.config) tx.Customer = NewCustomerClient(tx.config) diff --git a/openmeter/ent/schema/charges.go b/openmeter/ent/schema/charges.go index 61bbd721e6..4d3e4de564 100644 --- a/openmeter/ent/schema/charges.go +++ b/openmeter/ent/schema/charges.go @@ -156,6 +156,7 @@ func (Charge) Edges() []ent.Edge { // Billing edge.To("billing_invoice_lines", BillingInvoiceLine.Type), edge.To("billing_split_line_groups", BillingInvoiceSplitLineGroup.Type), + edge.To("credit_realization_lineages", CreditRealizationLineage.Type), } } diff --git a/openmeter/ent/schema/creditrealizationlineage.go b/openmeter/ent/schema/creditrealizationlineage.go new file mode 100644 index 0000000000..822ab9fa54 --- /dev/null +++ b/openmeter/ent/schema/creditrealizationlineage.go @@ -0,0 +1,147 @@ +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/entutils" +) + +func creditRealizationLineageNow() time.Time { + return clock.Now().Truncate(time.Microsecond) +} + +type CreditRealizationLineage struct { + ent.Schema +} + +func (CreditRealizationLineage) Mixin() []ent.Mixin { + return []ent.Mixin{ + entutils.IDMixin{}, + entutils.NamespaceMixin{}, + } +} + +func (CreditRealizationLineage) Fields() []ent.Field { + return []ent.Field{ + field.String("charge_id"). + SchemaType(map[string]string{ + dialect.Postgres: "char(26)", + }). + NotEmpty(). + Immutable(), + field.String("root_realization_id"). + SchemaType(map[string]string{ + dialect.Postgres: "char(26)", + }). + NotEmpty(). + Immutable(), + field.String("customer_id"). + SchemaType(map[string]string{ + dialect.Postgres: "char(26)", + }). + NotEmpty(). + Immutable(), + field.String("currency"). + GoType(currencyx.Code("")). + NotEmpty(). + Immutable(). + SchemaType(map[string]string{ + dialect.Postgres: "varchar(3)", + }), + field.Enum("origin_kind"). + GoType(creditrealization.LineageOriginKind("")). + Immutable(), + field.Time("created_at"). + Default(creditRealizationLineageNow). + Immutable(), + } +} + +func (CreditRealizationLineage) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("charge", Charge.Type). + Field("charge_id"). + Ref("credit_realization_lineages"). + Required(). + Unique(). + Immutable(), + edge.To("segments", CreditRealizationLineageSegment.Type), + } +} + +func (CreditRealizationLineage) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("namespace", "root_realization_id").Unique(), + index.Fields("namespace", "charge_id"), + index.Fields("namespace", "customer_id"), + } +} + +type CreditRealizationLineageSegment struct { + ent.Schema +} + +func (CreditRealizationLineageSegment) Mixin() []ent.Mixin { + return []ent.Mixin{ + entutils.IDMixin{}, + } +} + +func (CreditRealizationLineageSegment) Fields() []ent.Field { + return []ent.Field{ + field.String("lineage_id"). + SchemaType(map[string]string{ + dialect.Postgres: "char(26)", + }). + NotEmpty(). + Immutable(), + field.Other("amount", alpacadecimal.Decimal{}). + SchemaType(map[string]string{ + dialect.Postgres: "numeric", + }), + field.Enum("state"). + GoType(creditrealization.LineageSegmentState("")). + Immutable(), + field.String("backing_transaction_group_id"). + SchemaType(map[string]string{ + dialect.Postgres: "char(26)", + }). + Optional(). + NotEmpty(). + Nillable(), + field.Time("closed_at"). + Optional(). + Nillable(), + field.Time("created_at"). + Default(creditRealizationLineageNow). + Immutable(), + } +} + +func (CreditRealizationLineageSegment) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("lineage", CreditRealizationLineage.Type). + Ref("segments"). + Field("lineage_id"). + Required(). + Unique(). + Immutable(), + } +} + +func (CreditRealizationLineageSegment) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("lineage_id"), + index.Fields("lineage_id", "closed_at"), + } +} diff --git a/openmeter/ledger/annotations.go b/openmeter/ledger/annotations.go index 9526a0569f..fc93b6e81b 100644 --- a/openmeter/ledger/annotations.go +++ b/openmeter/ledger/annotations.go @@ -1,10 +1,24 @@ package ledger -import "github.com/openmeterio/openmeter/pkg/models" +import ( + "fmt" + + "github.com/openmeterio/openmeter/pkg/models" +) const ( AnnotationChargeNamespace = "ledger.charge.namespace" AnnotationChargeID = "ledger.charge.id" + + AnnotationTransactionTemplateName = "ledger.transaction.template_name" + AnnotationTransactionDirection = "ledger.transaction.direction" +) + +type TransactionDirection string + +const ( + TransactionDirectionForward TransactionDirection = "forward" + TransactionDirectionCorrection TransactionDirection = "correction" ) func ChargeAnnotations(chargeID models.NamespacedID) models.Annotations { @@ -13,3 +27,44 @@ func ChargeAnnotations(chargeID models.NamespacedID) models.Annotations { AnnotationChargeID: chargeID.ID, } } + +func TransactionAnnotations(templateName string, direction TransactionDirection) models.Annotations { + return models.Annotations{ + AnnotationTransactionTemplateName: templateName, + AnnotationTransactionDirection: string(direction), + } +} + +func TransactionTemplateNameFromAnnotations(annotations models.Annotations) (string, error) { + raw, ok := annotations[AnnotationTransactionTemplateName] + if !ok { + return "", fmt.Errorf("transaction template name annotation is required") + } + + name, ok := raw.(string) + if !ok || name == "" { + return "", fmt.Errorf("transaction template name annotation is invalid") + } + + return name, nil +} + +func TransactionDirectionFromAnnotations(annotations models.Annotations) (TransactionDirection, error) { + raw, ok := annotations[AnnotationTransactionDirection] + if !ok { + return "", fmt.Errorf("transaction direction annotation is required") + } + + value, ok := raw.(string) + if !ok || value == "" { + return "", fmt.Errorf("transaction direction annotation is invalid") + } + + direction := TransactionDirection(value) + switch direction { + case TransactionDirectionForward, TransactionDirectionCorrection: + return direction, nil + default: + return "", fmt.Errorf("invalid transaction direction annotation %q", value) + } +} diff --git a/openmeter/ledger/chargeadapter/creditpurchase.go b/openmeter/ledger/chargeadapter/creditpurchase.go index 05dfbec14d..d4d38b3891 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase.go +++ b/openmeter/ledger/chargeadapter/creditpurchase.go @@ -198,7 +198,7 @@ func (h *creditPurchaseHandler) issueCreditPurchase(ctx context.Context, charge issuableAmount = alpacadecimal.Zero } - var templates []transactions.Resolver + var templates []transactions.TransactionTemplate if advanceAttributionAmount.IsPositive() { templates = append(templates, transactions.AttributeCustomerAdvanceReceivableCostBasisTemplate{ diff --git a/openmeter/ledger/chargeadapter/creditpurchase_test.go b/openmeter/ledger/chargeadapter/creditpurchase_test.go index e19b052b91..c3df7d90a1 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase_test.go +++ b/openmeter/ledger/chargeadapter/creditpurchase_test.go @@ -336,7 +336,7 @@ func (e *creditPurchaseHandlerTestEnv) createAdvanceExposure(t *testing.T, amoun Amount: amount, Currency: e.Currency, }, - transactions.TransferCustomerFBOBucketToAccruedTemplate{ + transactions.TransferCustomerFBOAdvanceToAccruedTemplate{ At: e.Now(), Amount: amount, Currency: e.Currency, diff --git a/openmeter/ledger/chargeadapter/flatfee.go b/openmeter/ledger/chargeadapter/flatfee.go index 28577715a9..580fc40708 100644 --- a/openmeter/ledger/chargeadapter/flatfee.go +++ b/openmeter/ledger/chargeadapter/flatfee.go @@ -5,37 +5,37 @@ import ( "fmt" "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ledger" - ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + "github.com/openmeterio/openmeter/openmeter/ledger/collector" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/pkg/models" - "github.com/openmeterio/openmeter/pkg/timeutil" ) // flatFeeHandler maps charge lifecycle events to ledger transaction templates type flatFeeHandler struct { - ledger ledger.Ledger - accountResolver ledger.AccountResolver - accountService ledgeraccount.Service + ledger ledger.Ledger + deps transactions.ResolverDependencies + collector collector.Service } var _ flatfee.Handler = (*flatFeeHandler)(nil) func NewFlatFeeHandler( ledger ledger.Ledger, - accountResolver ledger.AccountResolver, - accountService ledgeraccount.Service, + deps transactions.ResolverDependencies, + collectorService collector.Service, ) flatfee.Handler { return &flatFeeHandler{ - ledger: ledger, - accountResolver: accountResolver, - accountService: accountService, + ledger: ledger, + deps: deps, + collector: collectorService, } } @@ -63,15 +63,24 @@ func (h *flatFeeHandler) OnAssignedToInvoice(ctx context.Context, input flatfee. return nil, nil } - groupID, inputs, err := h.allocateCreditsToAccrued(ctx, input.Charge, input.PreTaxTotalAmount) + realizations, err := h.collector.CollectToAccrued(ctx, collector.CollectToAccruedInput{ + Namespace: input.Charge.Namespace, + ChargeID: input.Charge.ID, + CustomerID: input.Charge.Intent.CustomerID, + At: input.Charge.Intent.InvoiceAt, + Currency: input.Charge.Intent.Currency, + SettlementMode: input.Charge.Intent.SettlementMode, + ServicePeriod: input.ServicePeriod, + Amount: input.PreTaxTotalAmount, + }) if err != nil { return nil, err } - if groupID == "" { + if len(realizations) == 0 { return nil, nil } - return creditRealizationsFromCollectedInputs(input.ServicePeriod, groupID, inputs...), nil + return realizations, nil } // OnFlatFeeStandardInvoiceUsageAccrued handles the portion of usage not covered by FBO credits. @@ -106,7 +115,7 @@ func (h *flatFeeHandler) OnInvoiceUsageAccrued(ctx context.Context, input flatfe inputs, err := transactions.ResolveTransactions( ctx, - h.resolverDependencies(), + h.deps, transactions.ResolutionScope{ CustomerID: customerID, Namespace: input.Charge.Namespace, @@ -115,7 +124,7 @@ func (h *flatFeeHandler) OnInvoiceUsageAccrued(ctx context.Context, input flatfe At: input.Charge.Intent.InvoiceAt, Amount: amount, Currency: input.Charge.Intent.Currency, - CostBasis: invoiceCostBasis(), + CostBasis: invoiceCostBasis, }, ) if err != nil { @@ -151,48 +160,59 @@ func (h *flatFeeHandler) OnCreditsOnlyUsageAccrued(ctx context.Context, input fl return nil, fmt.Errorf("credits only usage accrued: %w", err) } - groupID, inputs, err := h.allocateCreditsToAccrued(ctx, input.Charge, input.AmountToAllocate) + realizations, err := h.collector.CollectToAccrued(ctx, collector.CollectToAccruedInput{ + Namespace: input.Charge.Namespace, + ChargeID: input.Charge.ID, + CustomerID: input.Charge.Intent.CustomerID, + At: input.Charge.Intent.InvoiceAt, + Currency: input.Charge.Intent.Currency, + SettlementMode: input.Charge.Intent.SettlementMode, + ServicePeriod: input.Charge.Intent.ServicePeriod, + Amount: input.AmountToAllocate, + }) if err != nil { return nil, err } - if groupID == "" { + if len(realizations) == 0 { return nil, nil } - return creditRealizationsFromCollectedInputs(input.Charge.Intent.ServicePeriod, groupID, inputs...), nil + return realizations, nil } func (h *flatFeeHandler) OnCreditsOnlyUsageAccruedCorrection(ctx context.Context, input flatfee.CreditsOnlyUsageAccruedCorrectionInput) (creditrealization.CreateCorrectionInputs, error) { - if err := input.Validate(); err != nil { + currencyCalculator, err := input.Charge.Intent.Currency.Calculator() + if err != nil { + return nil, fmt.Errorf("get currency calculator: %w", err) + } + + if err := input.ValidateWith(currencyCalculator); err != nil { return nil, err } - return nil, fmt.Errorf("credits only usage accrued correction is not implemented") + return h.collector.CorrectCollectedAccrued(ctx, collector.CorrectCollectedAccruedInput{ + Namespace: input.Charge.Namespace, + ChargeID: input.Charge.ID, + CustomerID: input.Charge.Intent.CustomerID, + AllocateAt: input.AllocateAt, + Corrections: input.Corrections, + LineageSegmentsByRealization: input.LineageSegmentsByRealization, + }) } -// OnFlatFeePaymentAuthorized is the current revenue recognition point. -// It replenishes receivable from wash for the directly-invoiced portion, and -// recognizes revenue by moving from customer_accrued to earnings. +// OnFlatFeePaymentAuthorized currently only stages receivable funding from wash +// for the directly-invoiced portion. Revenue recognition is handled elsewhere. func (h *flatFeeHandler) OnPaymentAuthorized(ctx context.Context, charge flatfee.Charge) (ledgertransaction.GroupReference, error) { if err := charge.Validate(); err != nil { return ledgertransaction.GroupReference{}, err } - // Compute the total amount to recognize from accrued into earnings. - // This includes both credit-backed (FBO) and receivable-backed portions. - totalRecognition := alpacadecimal.NewFromInt(0) - for _, cr := range charge.State.CreditRealizations { - totalRecognition = totalRecognition.Add(cr.Amount) - } - - // The receivable portion needs wash -> receivable replenishment. receivableReplenishment := alpacadecimal.NewFromInt(0) if charge.State.AccruedUsage != nil { receivableReplenishment = charge.State.AccruedUsage.Totals.Total - totalRecognition = totalRecognition.Add(charge.State.AccruedUsage.Totals.Total) } - if totalRecognition.IsZero() { + if receivableReplenishment.IsZero() { return ledgertransaction.GroupReference{}, nil } @@ -205,35 +225,19 @@ func (h *flatFeeHandler) OnPaymentAuthorized(ctx context.Context, charge flatfee ID: charge.ID, }) - var templates []transactions.Resolver - if receivableReplenishment.IsPositive() { - templates = append(templates, transactions.FundCustomerReceivableTemplate{ - At: charge.Intent.InvoiceAt, - Amount: receivableReplenishment, - Currency: charge.Intent.Currency, - CostBasis: invoiceCostBasis(), - }) - } - if totalRecognition.IsPositive() { - templates = append(templates, transactions.RecognizeEarningsFromAttributableAccruedTemplate{ - At: charge.Intent.InvoiceAt, - Amount: totalRecognition, - Currency: charge.Intent.Currency, - }) - } - - if len(templates) == 0 { - return ledgertransaction.GroupReference{}, nil - } - inputs, err := transactions.ResolveTransactions( ctx, - h.resolverDependencies(), + h.deps, transactions.ResolutionScope{ CustomerID: customerID, Namespace: charge.Namespace, }, - templates..., + transactions.FundCustomerReceivableTemplate{ + At: charge.Intent.InvoiceAt, + Amount: receivableReplenishment, + Currency: charge.Intent.Currency, + CostBasis: invoiceCostBasis, + }, ) if err != nil { return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) @@ -273,7 +277,7 @@ func (h *flatFeeHandler) OnPaymentSettled(ctx context.Context, charge flatfee.Ch inputs, err := transactions.ResolveTransactions( ctx, - h.resolverDependencies(), + h.deps, transactions.ResolutionScope{ CustomerID: customerID, Namespace: charge.Namespace, @@ -282,7 +286,7 @@ func (h *flatFeeHandler) OnPaymentSettled(ctx context.Context, charge flatfee.Ch At: charge.Intent.InvoiceAt, Amount: charge.State.AccruedUsage.Totals.Total, Currency: charge.Intent.Currency, - CostBasis: invoiceCostBasis(), + CostBasis: invoiceCostBasis, }, ) if err != nil { @@ -309,13 +313,6 @@ func (h *flatFeeHandler) OnPaymentUncollectible(_ context.Context, _ flatfee.Cha return ledgertransaction.GroupReference{}, fmt.Errorf("flat fee uncollectible write-off is not yet implemented") } -func (h *flatFeeHandler) resolverDependencies() transactions.ResolverDependencies { - return transactions.ResolverDependencies{ - AccountService: h.accountResolver, - SubAccountService: h.accountService, - } -} - func validateSettlementMode(actual productcatalog.SettlementMode, allowed ...productcatalog.SettlementMode) error { for _, candidate := range allowed { if actual == candidate { @@ -326,125 +323,4 @@ func validateSettlementMode(actual productcatalog.SettlementMode, allowed ...pro return fmt.Errorf("unsupported settlement mode %q", actual) } -func (h *flatFeeHandler) allocateCreditsToAccrued(ctx context.Context, charge flatfee.Charge, amount alpacadecimal.Decimal) (string, []ledger.TransactionInput, error) { - customerID := customer.CustomerID{ - Namespace: charge.Namespace, - ID: charge.Intent.CustomerID, - } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) - - inputs, err := transactions.ResolveTransactions( - ctx, - h.resolverDependencies(), - transactions.ResolutionScope{ - CustomerID: customerID, - Namespace: charge.Namespace, - }, - transactions.TransferCustomerFBOToAccruedTemplate{ - At: charge.Intent.InvoiceAt, - Amount: amount, - Currency: charge.Intent.Currency, - }, - ) - if err != nil { - return "", nil, fmt.Errorf("resolve transactions: %w", err) - } - - collectedAmount := sumCollectedFBOAmount(inputs...) - shortfall := amount.Sub(collectedAmount) - if charge.Intent.SettlementMode == productcatalog.CreditOnlySettlementMode && shortfall.IsPositive() { - advanceInputs, err := transactions.ResolveTransactions( - ctx, - h.resolverDependencies(), - transactions.ResolutionScope{ - CustomerID: customerID, - Namespace: charge.Namespace, - }, - transactions.IssueCustomerReceivableTemplate{ - At: charge.Intent.InvoiceAt, - Amount: shortfall, - Currency: charge.Intent.Currency, - }, - transactions.TransferCustomerFBOBucketToAccruedTemplate{ - At: charge.Intent.InvoiceAt, - Amount: shortfall, - Currency: charge.Intent.Currency, - }, - ) - if err != nil { - return "", nil, fmt.Errorf("resolve advance transactions: %w", err) - } - - inputs = append(inputs, advanceInputs...) - } - - if len(inputs) == 0 { - return "", nil, nil - } - - transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( - charge.Namespace, - annotations, - inputs..., - )) - if err != nil { - return "", nil, fmt.Errorf("commit ledger transaction group: %w", err) - } - - return transactionGroup.ID().ID, inputs, nil -} - -func creditRealizationsFromCollectedInputs(servicePeriod timeutil.ClosedPeriod, transactionGroupID string, inputs ...ledger.TransactionInput) creditrealization.CreateAllocationInputs { - out := make(creditrealization.CreateAllocationInputs, 0, len(inputs)) - for _, input := range inputs { - if input == nil { - continue - } - for _, entry := range input.EntryInputs() { - if entry.Amount().IsNegative() && entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO { - out = append(out, creditrealization.CreateAllocationInput{ - ServicePeriod: servicePeriod, - Amount: entry.Amount().Abs(), - LedgerTransaction: ledgertransaction.GroupReference{ - TransactionGroupID: transactionGroupID, - }, - }) - } - } - } - - return out -} - -func invoiceCostBasis() *alpacadecimal.Decimal { - value := alpacadecimal.NewFromInt(1) - return &value -} - -func sumCollectedFBOAmount(inputs ...ledger.TransactionInput) alpacadecimal.Decimal { - total := alpacadecimal.Zero - for _, input := range inputs { - if input == nil { - continue - } - for _, entry := range input.EntryInputs() { - if entry.Amount().IsNegative() && entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO { - total = total.Add(entry.Amount().Abs()) - } - } - } - - return total -} - -func settledBalanceForSubAccount(ctx context.Context, subAccount ledger.SubAccount) (alpacadecimal.Decimal, error) { - balance, err := subAccount.GetBalance(ctx) - if err != nil { - return alpacadecimal.Decimal{}, fmt.Errorf("get balance for sub-account %s: %w", subAccount.Address().SubAccountID(), err) - } - - return balance.Settled(), nil -} +var invoiceCostBasis = lo.ToPtr(alpacadecimal.NewFromInt(1)) diff --git a/openmeter/ledger/chargeadapter/flatfee_test.go b/openmeter/ledger/chargeadapter/flatfee_test.go index 5cf51195ec..d1f856efc9 100644 --- a/openmeter/ledger/chargeadapter/flatfee_test.go +++ b/openmeter/ledger/chargeadapter/flatfee_test.go @@ -19,6 +19,7 @@ import ( ledgertransactiongroupdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransactiongroup" "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" ledgertestutils "github.com/openmeterio/openmeter/openmeter/ledger/testutils" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog" @@ -162,6 +163,115 @@ func TestOnFlatFeeCreditsOnlyUsageAccrued(t *testing.T) { }) } +func TestOnFlatFeeCreditsOnlyUsageAccruedCorrection(t *testing.T) { + t.Run("credit_only reverses advance-backed accrual", func(t *testing.T) { + env := newFlatFeeHandlerTestEnv(t) + + charge := env.newCreditsOnlyCharge(alpacadecimal.NewFromInt(30)) + allocations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeflatfee.OnCreditsOnlyUsageAccruedInput{ + Charge: charge, + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, allocations, 1) + + chargeWithRealizations := env.newChargeWithCreditRealizationsAndAccruedUsage(allocations, alpacadecimal.Zero) + chargeWithRealizations.Intent.SettlementMode = productcatalog.CreditOnlySettlementMode + + currencyCalculator, err := chargeWithRealizations.Intent.Currency.Calculator() + require.NoError(t, err) + + correctionsRequest, err := chargeWithRealizations.State.CreditRealizations.CreateCorrectionRequest(alpacadecimal.NewFromInt(-30), currencyCalculator) + require.NoError(t, err) + + corrections, err := env.handler.OnCreditsOnlyUsageAccruedCorrection(t.Context(), chargeflatfee.CreditsOnlyUsageAccruedCorrectionInput{ + Charge: chargeWithRealizations, + AllocateAt: env.Now(), + Corrections: correctionsRequest, + }) + require.NoError(t, err) + require.Len(t, corrections, 1) + require.True(t, corrections[0].Amount.Equal(alpacadecimal.NewFromInt(-30))) + + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownFboSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) + }) + + t.Run("credit_only reverses partial funded accrual in reverse priority order", func(t *testing.T) { + env := newFlatFeeHandlerTestEnv(t) + + priorityTwo := env.fundPriority(t, 2, 20) + priorityOne := env.fundPriority(t, 1, 30) + + charge := env.newCreditsOnlyCharge(alpacadecimal.NewFromInt(50)) + allocations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeflatfee.OnCreditsOnlyUsageAccruedInput{ + Charge: charge, + AmountToAllocate: alpacadecimal.NewFromInt(50), + }) + require.NoError(t, err) + require.Len(t, allocations, 2) + + chargeWithRealizations := env.newChargeWithCreditRealizationsAndAccruedUsage(allocations, alpacadecimal.Zero) + chargeWithRealizations.Intent.SettlementMode = productcatalog.CreditOnlySettlementMode + + currencyCalculator, err := chargeWithRealizations.Intent.Currency.Calculator() + require.NoError(t, err) + + correctionsRequest, err := chargeWithRealizations.State.CreditRealizations.CreateCorrectionRequest(alpacadecimal.NewFromInt(-35), currencyCalculator) + require.NoError(t, err) + + corrections, err := env.handler.OnCreditsOnlyUsageAccruedCorrection(t.Context(), chargeflatfee.CreditsOnlyUsageAccruedCorrectionInput{ + Charge: chargeWithRealizations, + AllocateAt: env.Now(), + Corrections: correctionsRequest, + }) + require.NoError(t, err) + require.Len(t, corrections, 2) + + require.True(t, env.sumBalance(t, priorityOne).Equal(alpacadecimal.NewFromInt(15))) + require.True(t, env.sumBalance(t, priorityTwo).Equal(alpacadecimal.NewFromInt(20))) + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(15))) + }) + + t.Run("credit_only mixed funded and advance correction only unwinds the advance companion once", func(t *testing.T) { + env := newFlatFeeHandlerTestEnv(t) + + priorityOne := env.fundPriority(t, 1, 20) + + charge := env.newCreditsOnlyCharge(alpacadecimal.NewFromInt(30)) + allocations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeflatfee.OnCreditsOnlyUsageAccruedInput{ + Charge: charge, + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, allocations, 2) + + chargeWithRealizations := env.newChargeWithCreditRealizationsAndAccruedUsage(allocations, alpacadecimal.Zero) + chargeWithRealizations.Intent.SettlementMode = productcatalog.CreditOnlySettlementMode + + currencyCalculator, err := chargeWithRealizations.Intent.Currency.Calculator() + require.NoError(t, err) + + correctionsRequest, err := chargeWithRealizations.State.CreditRealizations.CreateCorrectionRequest(alpacadecimal.NewFromInt(-30), currencyCalculator) + require.NoError(t, err) + + corrections, err := env.handler.OnCreditsOnlyUsageAccruedCorrection(t.Context(), chargeflatfee.CreditsOnlyUsageAccruedCorrectionInput{ + Charge: chargeWithRealizations, + AllocateAt: env.Now(), + Corrections: correctionsRequest, + }) + require.NoError(t, err) + require.Len(t, corrections, 2) + + require.True(t, env.sumBalance(t, priorityOne).Equal(alpacadecimal.NewFromInt(20))) + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownFboSubAccount(t)).Equal(alpacadecimal.Zero)) + }) +} + func TestOnFlatFeeStandardInvoiceUsageAccrued(t *testing.T) { t.Run("credit_then_invoice books receivable-backed usage into accrued", func(t *testing.T) { env := newFlatFeeHandlerTestEnv(t) @@ -196,7 +306,7 @@ func TestOnFlatFeeStandardInvoiceUsageAccrued(t *testing.T) { } func TestOnFlatFeePaymentAuthorized(t *testing.T) { - t.Run("credit_then_invoice recognizes revenue from receivable-backed accrued", func(t *testing.T) { + t.Run("credit_then_invoice stages receivable funding from receivable-backed accrued", func(t *testing.T) { env := newFlatFeeHandlerTestEnv(t) // First accrue usage: receivable → accrued @@ -213,12 +323,12 @@ func TestOnFlatFeePaymentAuthorized(t *testing.T) { require.True(t, env.sumBalance(t, env.receivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(-75))) require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(75))) require.True(t, env.sumBalance(t, env.washSubAccount(t)).Equal(alpacadecimal.NewFromInt(-75))) - // Accrued drained, earnings recognized - require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(0))) - require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(75))) + // No revenue recognition happens here anymore. + require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(75))) + require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) }) - t.Run("credit_then_invoice recognizes revenue from mixed FBO and receivable", func(t *testing.T) { + t.Run("credit_then_invoice mixed FBO and receivable only stages receivable funding", func(t *testing.T) { env := newFlatFeeHandlerTestEnv(t) // Fund FBO with 40 @@ -248,14 +358,14 @@ func TestOnFlatFeePaymentAuthorized(t *testing.T) { // Receivable funding stays staged until settlement. require.True(t, env.sumBalance(t, env.receivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(-20))) require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(20))) - // Accrued fully drained, all 60 recognized as earnings - require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(0))) - require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(0))) - require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(40))) - require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(20))) + // Existing accrued balances stay untouched until a later recognition flow. + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(40))) + require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(20))) + require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) }) - t.Run("credit_then_invoice does not overdraw accrued during recognition", func(t *testing.T) { + t.Run("credit_then_invoice does not touch accrued during authorization", func(t *testing.T) { env := newFlatFeeHandlerTestEnv(t) _, err := env.handler.OnInvoiceUsageAccrued(t.Context(), env.newAccrualInput(alpacadecimal.NewFromInt(30))) @@ -269,11 +379,11 @@ func TestOnFlatFeePaymentAuthorized(t *testing.T) { require.True(t, env.sumBalance(t, env.receivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(-30))) require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(75))) require.True(t, env.sumBalance(t, env.washSubAccount(t)).Equal(alpacadecimal.NewFromInt(-75))) - require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(0))) - require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.invoiceAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.invoiceEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) }) - t.Run("credit_only recognizes attributable credit-backed accrued without receivable funding", func(t *testing.T) { + t.Run("credit_only authorization is a no-op without receivable funding", func(t *testing.T) { env := newFlatFeeHandlerTestEnv(t) priorityOne := env.fundPriority(t, 1, 40) @@ -290,10 +400,10 @@ func TestOnFlatFeePaymentAuthorized(t *testing.T) { ref, err := env.handler.OnPaymentAuthorized(t.Context(), charge) require.NoError(t, err) - require.NotEmpty(t, ref.TransactionGroupID) + require.Empty(t, ref.TransactionGroupID) - require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) - require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) require.True(t, env.sumBalance(t, priorityOne).Equal(alpacadecimal.NewFromInt(10))) require.True(t, env.sumBalance(t, env.receivableSubAccount(t)).Equal(alpacadecimal.Zero)) require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) @@ -350,8 +460,8 @@ func TestOnFlatFeePaymentSettled(t *testing.T) { ref, err := env.handler.OnPaymentSettled(t.Context(), charge) require.NoError(t, err) require.Empty(t, ref.TransactionGroupID) - require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) - require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + require.True(t, env.sumBalance(t, env.creditEarningsSubAccount(t)).Equal(alpacadecimal.Zero)) require.True(t, env.sumBalance(t, env.receivableSubAccount(t)).Equal(alpacadecimal.Zero)) require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) }) @@ -375,13 +485,21 @@ type flatFeeHandlerTestEnv struct { func newFlatFeeHandlerTestEnv(t *testing.T) *flatFeeHandlerTestEnv { base := ledgertestutils.NewIntegrationEnv(t, "chargeadapter-flatfee") + deps := transactions.ResolverDependencies{ + AccountService: base.Deps.ResolversService, + SubAccountService: base.Deps.AccountService, + } + collectorService := ledgercollector.NewService(ledgercollector.Config{ + Ledger: base.Deps.HistoricalLedger, + Dependencies: deps, + }) return &flatFeeHandlerTestEnv{ IntegrationEnv: base, handler: chargeadapter.NewFlatFeeHandler( base.Deps.HistoricalLedger, - base.Deps.ResolversService, - base.Deps.AccountService, + deps, + collectorService, ), } } @@ -590,7 +708,12 @@ func (e *flatFeeHandlerTestEnv) newChargeWithCreditRealizationsAndAccruedUsage(r NamespacedModel: models.NamespacedModel{ Namespace: e.Namespace, }, + ManagedModel: models.ManagedModel{ + CreatedAt: now, + UpdatedAt: now, + }, CreateInput: r, + SortHint: i, }) } diff --git a/openmeter/ledger/chargeadapter/helpers.go b/openmeter/ledger/chargeadapter/helpers.go new file mode 100644 index 0000000000..68f7ac984d --- /dev/null +++ b/openmeter/ledger/chargeadapter/helpers.go @@ -0,0 +1,19 @@ +package chargeadapter + +import ( + "context" + "fmt" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/ledger" +) + +func settledBalanceForSubAccount(ctx context.Context, subAccount ledger.SubAccount) (alpacadecimal.Decimal, error) { + balance, err := subAccount.GetBalance(ctx) + if err != nil { + return alpacadecimal.Decimal{}, fmt.Errorf("get balance for sub-account %s: %w", subAccount.Address().SubAccountID(), err) + } + + return balance.Settled(), nil +} diff --git a/openmeter/ledger/chargeadapter/usagebased.go b/openmeter/ledger/chargeadapter/usagebased.go new file mode 100644 index 0000000000..cc1ee2cac8 --- /dev/null +++ b/openmeter/ledger/chargeadapter/usagebased.go @@ -0,0 +1,93 @@ +package chargeadapter + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/ledger/collector" + "github.com/openmeterio/openmeter/openmeter/productcatalog" +) + +// usageBasedHandler maps usage-based credit-only lifecycle events to ledger transaction templates. +type usageBasedHandler struct { + collector collector.Service +} + +var _ usagebased.Handler = (*usageBasedHandler)(nil) + +func NewUsageBasedHandler(collectorService collector.Service) usagebased.Handler { + return &usageBasedHandler{ + collector: collectorService, + } +} + +func (h *usageBasedHandler) OnCreditsOnlyUsageAccrued(ctx context.Context, input usagebased.CreditsOnlyUsageAccruedInput) (creditrealization.CreateAllocationInputs, error) { + if err := input.Validate(); err != nil { + return nil, err + } + + if input.AmountToAllocate.IsZero() { + return nil, nil + } + + if err := validateSettlementMode(input.Charge.Intent.SettlementMode, productcatalog.CreditOnlySettlementMode); err != nil { + return nil, fmt.Errorf("credits only usage accrued: %w", err) + } + + realizations, err := h.collector.CollectToAccrued(ctx, collector.CollectToAccruedInput{ + Namespace: input.Charge.Namespace, + ChargeID: input.Charge.ID, + CustomerID: input.Charge.Intent.CustomerID, + At: input.AllocateAt, + Currency: input.Charge.Intent.Currency, + SettlementMode: input.Charge.Intent.SettlementMode, + ServicePeriod: input.Charge.Intent.ServicePeriod, + Amount: input.AmountToAllocate, + }) + if err != nil { + return nil, err + } + if len(realizations) == 0 { + return nil, nil + } + + return realizations, nil +} + +func (h *usageBasedHandler) OnCreditsOnlyUsageAccruedCorrection(ctx context.Context, input usagebased.CreditsOnlyUsageAccruedCorrectionInput) (creditrealization.CreateCorrectionInputs, error) { + if err := input.Charge.Validate(); err != nil { + return nil, fmt.Errorf("charge: %w", err) + } + + if err := input.Run.Validate(); err != nil { + return nil, fmt.Errorf("run: %w", err) + } + + if input.AllocateAt.IsZero() { + return nil, fmt.Errorf("allocate at is required") + } + + if err := validateSettlementMode(input.Charge.Intent.SettlementMode, productcatalog.CreditOnlySettlementMode); err != nil { + return nil, fmt.Errorf("credits only usage accrued correction: %w", err) + } + + currencyCalculator, err := input.Charge.Intent.Currency.Calculator() + if err != nil { + return nil, fmt.Errorf("get currency calculator: %w", err) + } + + if err := input.Corrections.ValidateWith(currencyCalculator); err != nil { + return nil, fmt.Errorf("corrections: %w", err) + } + + return h.collector.CorrectCollectedAccrued(ctx, collector.CorrectCollectedAccruedInput{ + Namespace: input.Charge.Namespace, + ChargeID: input.Charge.ID, + CustomerID: input.Charge.Intent.CustomerID, + AllocateAt: input.AllocateAt, + Corrections: input.Corrections, + LineageSegmentsByRealization: input.LineageSegmentsByRealization, + }) +} diff --git a/openmeter/ledger/chargeadapter/usagebased_test.go b/openmeter/ledger/chargeadapter/usagebased_test.go new file mode 100644 index 0000000000..3659bd794e --- /dev/null +++ b/openmeter/ledger/chargeadapter/usagebased_test.go @@ -0,0 +1,312 @@ +package chargeadapter_test + +import ( + "fmt" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + chargeusagebased "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/billing/models/totals" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" + ledgertestutils "github.com/openmeterio/openmeter/openmeter/ledger/testutils" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +func TestOnUsageBasedCreditsOnlyUsageAccrued(t *testing.T) { + t.Run("credit_only advances uncovered amount", func(t *testing.T) { + env := newUsageBasedHandlerTestEnv(t) + + realizations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedInput{ + Charge: env.newCreditsOnlyCharge(), + Run: env.newRun(), + AllocateAt: env.Now(), + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, realizations, 1) + require.True(t, realizations[0].Amount.Equal(alpacadecimal.NewFromInt(30))) + + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(-30))) + require.True(t, env.sumBalance(t, env.unknownFboSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(30))) + }) + + t.Run("credit_only collects from funded balances first", func(t *testing.T) { + env := newUsageBasedHandlerTestEnv(t) + + priorityOne := env.fundPriority(t, 1, 20) + + realizations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedInput{ + Charge: env.newCreditsOnlyCharge(), + Run: env.newRun(), + AllocateAt: env.Now(), + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, realizations, 2) + require.True(t, realizations[0].Amount.Equal(alpacadecimal.NewFromInt(20))) + require.True(t, realizations[1].Amount.Equal(alpacadecimal.NewFromInt(10))) + + require.True(t, env.sumBalance(t, priorityOne).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.NewFromInt(-10))) + require.True(t, env.sumBalance(t, env.creditAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(20))) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.NewFromInt(10))) + }) + + t.Run("zero amount is rejected by input validation", func(t *testing.T) { + env := newUsageBasedHandlerTestEnv(t) + + realizations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedInput{ + Charge: env.newCreditsOnlyCharge(), + Run: env.newRun(), + AllocateAt: env.Now(), + AmountToAllocate: alpacadecimal.Zero, + }) + require.Error(t, err) + require.Nil(t, realizations) + require.Contains(t, err.Error(), "amount to allocate must be positive") + }) +} + +func TestOnUsageBasedCreditsOnlyUsageAccruedCorrection(t *testing.T) { + t.Run("credit_only reverses advance-backed accrual", func(t *testing.T) { + env := newUsageBasedHandlerTestEnv(t) + + run := env.newRun() + allocations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedInput{ + Charge: env.newCreditsOnlyCharge(), + Run: run, + AllocateAt: env.Now(), + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, allocations, 1) + + run.CreditsAllocated = env.realizationsFromAllocations(allocations) + + currencyCalculator, err := env.Currency.Calculator() + require.NoError(t, err) + + correctionsRequest, err := run.CreditsAllocated.CreateCorrectionRequest(alpacadecimal.NewFromInt(-30), currencyCalculator) + require.NoError(t, err) + + corrections, err := env.handler.OnCreditsOnlyUsageAccruedCorrection(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedCorrectionInput{ + Charge: env.newCreditsOnlyCharge(), + Run: run, + AllocateAt: env.Now(), + Corrections: correctionsRequest, + }) + require.NoError(t, err) + require.Len(t, corrections, 1) + require.True(t, corrections[0].Amount.Equal(alpacadecimal.NewFromInt(-30))) + + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownFboSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) + }) +} + +type usageBasedHandlerTestEnv struct { + *ledgertestutils.IntegrationEnv + handler chargeusagebased.Handler +} + +func newUsageBasedHandlerTestEnv(t *testing.T) *usageBasedHandlerTestEnv { + base := ledgertestutils.NewIntegrationEnv(t, "chargeadapter-usagebased") + collectorService := ledgercollector.NewService(ledgercollector.Config{ + Ledger: base.Deps.HistoricalLedger, + Dependencies: transactions.ResolverDependencies{ + AccountService: base.Deps.ResolversService, + SubAccountService: base.Deps.AccountService, + }, + }) + + return &usageBasedHandlerTestEnv{ + IntegrationEnv: base, + handler: chargeadapter.NewUsageBasedHandler(collectorService), + } +} + +func (e *usageBasedHandlerTestEnv) newCreditsOnlyCharge() chargeusagebased.Charge { + now := time.Now().UTC() + featureID := "feature-api-requests" + servicePeriod := timeutil.ClosedPeriod{ + From: now.Add(-time.Hour), + To: now, + } + + return chargeusagebased.Charge{ + ChargeBase: chargeusagebased.ChargeBase{ + ManagedResource: meta.ManagedResource{ + NamespacedModel: models.NamespacedModel{ + Namespace: e.Namespace, + }, + ManagedModel: models.ManagedModel{ + CreatedAt: now, + UpdatedAt: now, + }, + ID: "usage-based-charge", + }, + Intent: chargeusagebased.Intent{ + Intent: meta.Intent{ + Name: "Usage based", + ManagedBy: billing.SystemManagedLine, + CustomerID: e.CustomerID.ID, + Currency: currencyx.Code("USD"), + ServicePeriod: servicePeriod, + BillingPeriod: servicePeriod, + }, + InvoiceAt: now, + SettlementMode: productcatalog.CreditOnlySettlementMode, + FeatureKey: "api_requests", + Price: *productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromInt(1)}), + }, + Status: chargeusagebased.StatusActiveFinalRealizationProcessing, + State: chargeusagebased.State{ + FeatureID: featureID, + }, + }, + } +} + +func (e *usageBasedHandlerTestEnv) newRun() chargeusagebased.RealizationRun { + now := time.Now().UTC() + featureID := "feature-api-requests" + + return chargeusagebased.RealizationRun{ + RealizationRunBase: chargeusagebased.RealizationRunBase{ + ID: chargeusagebased.RealizationRunID(models.NamespacedID{ + Namespace: e.Namespace, + ID: "run-1", + }), + ManagedModel: models.ManagedModel{ + CreatedAt: now, + UpdatedAt: now, + }, + Type: chargeusagebased.RealizationRunTypeFinalRealization, + AsOf: now, + CollectionEnd: now, + MeterValue: alpacadecimal.NewFromInt(30), + FeatureID: featureID, + Totals: totals.Totals{ + Amount: alpacadecimal.NewFromInt(30), + CreditsTotal: alpacadecimal.NewFromInt(30), + Total: alpacadecimal.Zero, + }, + }, + } +} + +func (e *usageBasedHandlerTestEnv) fundPriority(t *testing.T, priority int, amount int64) ledger.SubAccount { + t.Helper() + + costBasis := alpacadecimal.Zero + subAccount, err := e.CustomerAccounts.FBOAccount.GetSubAccountForRoute(t.Context(), ledger.CustomerFBORouteParams{ + Currency: e.Currency, + CostBasis: &costBasis, + CreditPriority: priority, + }) + require.NoError(t, err) + + inputs, err := transactions.ResolveTransactions( + t.Context(), + transactions.ResolverDependencies{ + AccountService: e.Deps.ResolversService, + SubAccountService: e.Deps.AccountService, + }, + transactions.ResolutionScope{ + CustomerID: e.CustomerID, + Namespace: e.Namespace, + }, + transactions.IssueCustomerReceivableTemplate{ + At: e.Now(), + Amount: alpacadecimal.NewFromInt(amount), + Currency: e.Currency, + CostBasis: &costBasis, + CreditPriority: &priority, + }, + transactions.FundCustomerReceivableTemplate{ + At: e.Now(), + Amount: alpacadecimal.NewFromInt(amount), + Currency: e.Currency, + CostBasis: &costBasis, + }, + transactions.SettleCustomerReceivablePaymentTemplate{ + At: e.Now(), + Amount: alpacadecimal.NewFromInt(amount), + Currency: e.Currency, + CostBasis: &costBasis, + }, + ) + require.NoError(t, err) + + _, err = e.Deps.HistoricalLedger.CommitGroup(t.Context(), transactions.GroupInputs( + e.Namespace, + nil, + inputs..., + )) + require.NoError(t, err) + + return subAccount +} + +func (e *usageBasedHandlerTestEnv) creditAccruedSubAccount(t *testing.T) ledger.SubAccount { + zeroCostBasis := alpacadecimal.Zero + return e.AccruedSubAccountWithCostBasis(t, &zeroCostBasis) +} + +func (e *usageBasedHandlerTestEnv) unknownAccruedSubAccount(t *testing.T) ledger.SubAccount { + return e.AccruedSubAccountWithCostBasis(t, nil) +} + +func (e *usageBasedHandlerTestEnv) unknownReceivableSubAccount(t *testing.T) ledger.SubAccount { + return e.ReceivableSubAccountWithCostBasis(t, nil) +} + +func (e *usageBasedHandlerTestEnv) unknownFboSubAccount(t *testing.T) ledger.SubAccount { + subAccount, err := e.CustomerAccounts.FBOAccount.GetSubAccountForRoute(t.Context(), ledger.CustomerFBORouteParams{ + Currency: e.Currency, + CreditPriority: ledger.DefaultCustomerFBOPriority, + }) + require.NoError(t, err) + + return subAccount +} + +func (e *usageBasedHandlerTestEnv) sumBalance(t *testing.T, subAccount ledger.SubAccount) alpacadecimal.Decimal { + return e.SumBalance(t, subAccount) +} + +func (e *usageBasedHandlerTestEnv) realizationsFromAllocations(allocations creditrealization.CreateAllocationInputs) creditrealization.Realizations { + now := time.Now().UTC() + + out := make(creditrealization.Realizations, 0, len(allocations)) + for i, allocation := range allocations.AsCreateInputs() { + allocation.ID = fmt.Sprintf("cr-%d", i) + out = append(out, creditrealization.Realization{ + NamespacedModel: models.NamespacedModel{ + Namespace: e.Namespace, + }, + ManagedModel: models.ManagedModel{ + CreatedAt: now, + UpdatedAt: now, + }, + CreateInput: allocation, + SortHint: i, + }) + } + + return out +} diff --git a/openmeter/ledger/collector/collect.go b/openmeter/ledger/collector/collect.go new file mode 100644 index 0000000000..e22a2ed1e2 --- /dev/null +++ b/openmeter/ledger/collector/collect.go @@ -0,0 +1,185 @@ +package collector + +import ( + "context" + "fmt" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +type accrualCollector struct { + ledger ledger.Ledger + deps transactions.ResolverDependencies +} + +type collectedInputs []ledger.TransactionInput + +func (c *accrualCollector) collect(ctx context.Context, input CollectToAccruedInput) (creditrealization.CreateAllocationInputs, error) { + if input.Amount.IsZero() { + return nil, nil + } + + inputs, err := c.resolveCollectedInputs(ctx, input, input.Amount) + if err != nil { + return nil, err + } + + // Credit-only: if the wallet didn't cover the full accrual, issue advance and + // move that slice through the advance-to-accrued path. + if shortfall := input.Amount.Sub(collectedInputs(inputs).collectedFBOAmount()); c.shouldAdvanceShortfall(input, shortfall) { + advanceInputs, err := c.resolveAdvanceInputs(ctx, input, shortfall) + if err != nil { + return nil, err + } + + inputs = append(inputs, advanceInputs...) + } + + if len(inputs) == 0 { + return nil, nil + } + + transactionGroup, err := c.ledger.CommitGroup(ctx, transactions.GroupInputs( + input.Namespace, + ledger.ChargeAnnotations(models.NamespacedID{ + Namespace: input.Namespace, + ID: input.ChargeID, + }), + inputs..., + )) + if err != nil { + return nil, fmt.Errorf("commit ledger transaction group: %w", err) + } + + return collectedInputs(inputs).toCreditRealizations(input.ServicePeriod, transactionGroup.ID().ID), nil +} + +func (c *accrualCollector) resolveCollectedInputs(ctx context.Context, input CollectToAccruedInput, amount alpacadecimal.Decimal) ([]ledger.TransactionInput, error) { + inputs, err := transactions.ResolveTransactions( + ctx, + c.deps, + c.resolutionScope(input), + transactions.TransferCustomerFBOToAccruedTemplate{ + At: input.At, + Amount: amount, + Currency: input.Currency, + }, + ) + if err != nil { + return nil, fmt.Errorf("resolve transactions: %w", err) + } + + return inputs, nil +} + +func (c *accrualCollector) resolveAdvanceInputs(ctx context.Context, input CollectToAccruedInput, amount alpacadecimal.Decimal) ([]ledger.TransactionInput, error) { + inputs, err := transactions.ResolveTransactions( + ctx, + c.deps, + c.resolutionScope(input), + transactions.IssueCustomerReceivableTemplate{ + At: input.At, + Amount: amount, + Currency: input.Currency, + }, + transactions.TransferCustomerFBOAdvanceToAccruedTemplate{ + At: input.At, + Amount: amount, + Currency: input.Currency, + }, + ) + if err != nil { + return nil, fmt.Errorf("resolve advance transactions: %w", err) + } + + return inputs, nil +} + +func (c *accrualCollector) shouldAdvanceShortfall(input CollectToAccruedInput, shortfall alpacadecimal.Decimal) bool { + return input.SettlementMode == productcatalog.CreditOnlySettlementMode && shortfall.IsPositive() +} + +func (c *accrualCollector) resolutionScope(input CollectToAccruedInput) transactions.ResolutionScope { + return transactions.ResolutionScope{ + CustomerID: customer.CustomerID{ + Namespace: input.Namespace, + ID: input.CustomerID, + }, + Namespace: input.Namespace, + } +} + +func (i collectedInputs) toCreditRealizations(servicePeriod timeutil.ClosedPeriod, transactionGroupID string) creditrealization.CreateAllocationInputs { + out := make(creditrealization.CreateAllocationInputs, 0, len(i)) + for _, input := range i { + if input == nil { + continue + } + + annotations := creditRealizationAnnotationsForCollectedInput(input) + // One realization row per FBO debit on the resolved inputs (the spend from balance). + for _, entry := range input.EntryInputs() { + if entry.Amount().IsNegative() && entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO { + out = append(out, creditrealization.CreateAllocationInput{ + Annotations: annotations, + ServicePeriod: servicePeriod, + Amount: entry.Amount().Abs(), + LedgerTransaction: ledgertransaction.GroupReference{ + TransactionGroupID: transactionGroupID, + }, + }) + } + } + } + + return out +} + +func (i collectedInputs) collectedFBOAmount() alpacadecimal.Decimal { + total := alpacadecimal.Zero + for _, input := range i { + if input == nil { + continue + } + for _, entry := range input.EntryInputs() { + if entry.Amount().IsNegative() && entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO { + total = total.Add(entry.Amount().Abs()) + } + } + } + + return total +} + +func creditRealizationAnnotationsForCollectedInput(input ledger.TransactionInput) models.Annotations { + templateName, err := ledger.TransactionTemplateNameFromAnnotations(input.Annotations()) + if err != nil { + return input.Annotations() + } + + var originKind creditrealization.LineageOriginKind + switch templateName { + case transactions.TemplateName(transactions.TransferCustomerFBOToAccruedTemplate{}): + originKind = creditrealization.LineageOriginKindRealCredit + case transactions.TemplateName(transactions.TransferCustomerFBOAdvanceToAccruedTemplate{}): + originKind = creditrealization.LineageOriginKindAdvance + default: + return input.Annotations() + } + + annotations, err := input.Annotations().Merge(creditrealization.LineageAnnotations(originKind)) + if err != nil { + return input.Annotations() + } + + return annotations +} diff --git a/openmeter/ledger/collector/correct.go b/openmeter/ledger/collector/correct.go new file mode 100644 index 0000000000..ffd33c440e --- /dev/null +++ b/openmeter/ledger/collector/correct.go @@ -0,0 +1,454 @@ +package collector + +import ( + "context" + "fmt" + "sort" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" +) + +type accrualCorrector struct { + ledger ledger.Ledger + deps transactions.ResolverDependencies +} + +// collectedSource is one logical “collection” in the group: the FBO→accrued +// forward tx, plus the receivable issue when that slice was advance-backed. +type collectedSource struct { + transaction ledger.Transaction + group ledger.TransactionGroup + advanceReceivableIssueTransaction ledger.Transaction +} + +type transactionCorrectionPlan struct { + transaction ledger.Transaction + group ledger.TransactionGroup + amount alpacadecimal.Decimal +} + +type plannedAction interface { + isPlannedAction() +} + +type plannedTransactionCorrection struct { + transaction ledger.Transaction + group ledger.TransactionGroup + amount alpacadecimal.Decimal +} + +func (plannedTransactionCorrection) isPlannedAction() {} + +// plannedDirectInputs are inputs we already resolved (e.g. reissue); they skip +// the merge-and-CorrectTransaction path below. +type plannedDirectInputs struct { + inputs []ledger.TransactionInput +} + +func (plannedDirectInputs) isPlannedAction() {} + +func (c *accrualCorrector) correct(ctx context.Context, input CorrectCollectedAccruedInput) (creditrealization.CreateCorrectionInputs, error) { + if len(input.Corrections) == 0 { + return nil, nil + } + + // Plan first, execute later, so we can merge overlapping corrections cleanly. + actions := make([]plannedAction, 0, len(input.Corrections)) + for _, correction := range input.Corrections { + correctionActions, err := c.planCorrection(ctx, input, correction) + if err != nil { + return nil, err + } + actions = append(actions, correctionActions...) + } + + resolvedInputs, err := c.resolvePlannedInputs(ctx, input, actions) + if err != nil { + return nil, err + } + if len(resolvedInputs) == 0 { + return nil, nil + } + + // Write the whole correction batch as one group and point every new correction + // realization at that group. + transactionGroup, err := c.ledger.CommitGroup(ctx, transactions.GroupInputs( + input.Namespace, + ledger.ChargeAnnotations(models.NamespacedID{ + Namespace: input.Namespace, + ID: input.ChargeID, + }), + resolvedInputs..., + )) + if err != nil { + return nil, fmt.Errorf("commit correction transaction group: %w", err) + } + + out := make(creditrealization.CreateCorrectionInputs, 0, len(input.Corrections)) + for _, correction := range input.Corrections { + out = append(out, creditrealization.CreateCorrectionInput{ + LedgerTransaction: ledgertransaction.GroupReference{ + TransactionGroupID: transactionGroup.ID().ID, + }, + Amount: correction.Amount, + CorrectsRealizationID: correction.Allocation.ID, + }) + } + + return out, nil +} + +func (c *accrualCorrector) planCorrection(ctx context.Context, input CorrectCollectedAccruedInput, correction creditrealization.CorrectionRequestItem) ([]plannedAction, error) { + originalGroup, err := c.originalGroup(ctx, input, correction) + if err != nil { + return nil, err + } + + // SortHint maps the realization back to the original collected source in the group. + source, err := c.collectedSourceBySortHint(originalGroup, correction.Allocation.SortHint) + if err != nil { + return nil, err + } + + // Older data may not have lineage yet, so fall back to first-order source correction. + segments := input.LineageSegmentsByRealization[correction.Allocation.ID] + if len(segments) == 0 { + return plannedSourceCorrectionActions(source, correction.Amount.Abs(), source.advanceReceivableIssueTransaction != nil), nil + } + + // Lineage tells us what this value looks like now, so consume that state first. + remaining := correction.Amount.Abs() + actions := make([]plannedAction, 0, len(segments)+2) + for _, segment := range sortCorrectionSegments(segments) { + if !remaining.IsPositive() { + break + } + + segmentAmount := minDecimal(segment.Amount, remaining) + if !segmentAmount.IsPositive() { + continue + } + + segmentActions, err := c.planSegmentCorrection(ctx, input, source, segment, segmentAmount) + if err != nil { + return nil, err + } + actions = append(actions, segmentActions...) + + remaining = remaining.Sub(segmentAmount) + } + + if remaining.IsPositive() { + return nil, fmt.Errorf("correction amount %s exceeds active lineage coverage for realization %s", correction.Amount.Abs().String(), correction.Allocation.ID) + } + + return actions, nil +} + +func (c *accrualCorrector) originalGroup(ctx context.Context, input CorrectCollectedAccruedInput, correction creditrealization.CorrectionRequestItem) (ledger.TransactionGroup, error) { + group, err := c.ledger.GetTransactionGroup(ctx, models.NamespacedID{ + Namespace: input.Namespace, + ID: correction.Allocation.LedgerTransaction.TransactionGroupID, + }) + if err != nil { + return nil, fmt.Errorf("get original transaction group %s: %w", correction.Allocation.LedgerTransaction.TransactionGroupID, err) + } + + return group, nil +} + +func (c *accrualCorrector) planSegmentCorrection(ctx context.Context, input CorrectCollectedAccruedInput, source collectedSource, segment lineage.Segment, amount alpacadecimal.Decimal) ([]plannedAction, error) { + // Each current segment state needs a slightly different unwind. + switch segment.State { + case creditrealization.LineageSegmentStateRealCredit: + return plannedSourceCorrectionActions(source, amount, false), nil + case creditrealization.LineageSegmentStateAdvanceUncovered: + return plannedSourceCorrectionActions(source, amount, true), nil + case creditrealization.LineageSegmentStateAdvanceBackfilled: + return c.planBackfilledAdvanceSegment(ctx, input, source, segment, amount) + default: + return nil, fmt.Errorf("unsupported active lineage segment state %s", segment.State) + } +} + +func (c *accrualCorrector) planBackfilledAdvanceSegment(ctx context.Context, input CorrectCollectedAccruedInput, source collectedSource, segment lineage.Segment, amount alpacadecimal.Decimal) ([]plannedAction, error) { + if segment.BackingTransactionGroupID == nil || *segment.BackingTransactionGroupID == "" { + return nil, fmt.Errorf("advance_backfilled segment missing backing transaction group id") + } + + // Backfilled advance means we have to unwind both the later backfill and the + // original advance-backed collection. + backingGroup, err := c.ledger.GetTransactionGroup(ctx, models.NamespacedID{ + Namespace: input.Namespace, + ID: *segment.BackingTransactionGroupID, + }) + if err != nil { + return nil, fmt.Errorf("get backing transaction group %s: %w", *segment.BackingTransactionGroupID, err) + } + + actions := make([]plannedAction, 0, 4) + if translateTx, err := c.forwardTransactionByTemplate(backingGroup, transactions.TemplateName(transactions.TranslateCustomerAccruedCostBasisTemplate{})); err == nil { + actions = append(actions, plannedTransactionCorrection{ + transaction: translateTx, + group: backingGroup, + amount: amount, + }) + } + + attributeTx, err := c.forwardTransactionByTemplate(backingGroup, transactions.TemplateName(transactions.AttributeCustomerAdvanceReceivableCostBasisTemplate{})) + if err != nil { + return nil, fmt.Errorf("find backing advance receivable attribution transaction in group %s: %w", backingGroup.ID().ID, err) + } + actions = append(actions, plannedTransactionCorrection{ + transaction: attributeTx, + group: backingGroup, + amount: amount, + }) + actions = append(actions, plannedSourceCorrectionActions(source, amount, true)...) + + // The purchased-credit-covered part becomes available credit again. + // We intentionally re-issue it into FBO and stop there: releasing purchased backing during + // correction does not trigger a fresh customer-wide backfill pass against other uncovered advance. + reissueInputs, err := c.reissueBackfilledCredit(ctx, input, backingGroup, amount) + if err != nil { + return nil, err + } + actions = append(actions, plannedDirectInputs{inputs: reissueInputs}) + + return actions, nil +} + +func (c *accrualCorrector) reissueBackfilledCredit(ctx context.Context, input CorrectCollectedAccruedInput, backingGroup ledger.TransactionGroup, amount alpacadecimal.Decimal) ([]ledger.TransactionInput, error) { + // Re-issue into the same known-cost bucket the backfill had used so the released value becomes + // ordinary purchased credit again. It can be consumed later, but we do not immediately redirect + // it onto some other uncovered advance during this correction flow. + currency, costBasis, err := c.backfilledIssueRoute(backingGroup) + if err != nil { + return nil, err + } + + resolved, err := transactions.ResolveTransactions( + ctx, + c.deps, + transactions.ResolutionScope{ + CustomerID: customer.CustomerID{ + Namespace: input.Namespace, + ID: input.CustomerID, + }, + Namespace: input.Namespace, + }, + transactions.IssueCustomerReceivableTemplate{ + At: input.AllocateAt, + Amount: amount, + Currency: currency, + CostBasis: costBasis, + }, + ) + if err != nil { + return nil, fmt.Errorf("resolve re-issued purchased credit: %w", err) + } + + out := make([]ledger.TransactionInput, 0, len(resolved)) + for _, txInput := range resolved { + out = append(out, transactions.WithAnnotations(txInput, ledger.TransactionAnnotations( + transactions.TemplateName(transactions.IssueCustomerReceivableTemplate{}), + ledger.TransactionDirectionCorrection, + ))) + } + + return out, nil +} + +func plannedSourceCorrectionActions(source collectedSource, amount alpacadecimal.Decimal, includeAdvanceReceivable bool) []plannedAction { + // A source correction always offsets the original collection transaction itself. + // Advance-backed collection also needs the companion receivable-issue correction + // so the offset reduces the advance-side obligation instead of manufacturing credit. + actions := []plannedAction{ + plannedTransactionCorrection{ + transaction: source.transaction, + group: source.group, + amount: amount, + }, + } + + if includeAdvanceReceivable && source.advanceReceivableIssueTransaction != nil { + actions = append(actions, plannedTransactionCorrection{ + transaction: source.advanceReceivableIssueTransaction, + group: source.group, + amount: amount, + }) + } + + return actions +} + +func (c *accrualCorrector) resolvePlannedInputs(ctx context.Context, input CorrectCollectedAccruedInput, actions []plannedAction) ([]ledger.TransactionInput, error) { + // Merge by original transaction before executing, so template-specific correction + // still sees one aggregated amount per source. + mergedCorrections := make(map[string]*transactionCorrectionPlan, len(actions)) + correctionOrder := make([]string, 0, len(actions)) + out := make([]ledger.TransactionInput, 0, len(actions)) + + for _, action := range actions { + switch planned := action.(type) { + case plannedTransactionCorrection: + key := planned.transaction.ID().Namespace + ":" + planned.transaction.ID().ID + if existing, ok := mergedCorrections[key]; ok { + existing.amount = existing.amount.Add(planned.amount) + continue + } + + mergedCorrections[key] = &transactionCorrectionPlan{ + transaction: planned.transaction, + group: planned.group, + amount: planned.amount, + } + correctionOrder = append(correctionOrder, key) + case plannedDirectInputs: + out = append(out, planned.inputs...) + default: + return nil, fmt.Errorf("unsupported planned action %T", action) + } + } + + for _, key := range correctionOrder { + transactionPlan := mergedCorrections[key] + correctionInputs, err := transactions.CorrectTransaction(ctx, c.deps, transactions.CorrectionInput{ + At: input.AllocateAt, + Amount: transactionPlan.amount, + OriginalTransaction: transactionPlan.transaction, + OriginalGroup: transactionPlan.group, + }) + if err != nil { + return nil, fmt.Errorf("correct transaction %s: %w", transactionPlan.transaction.ID().ID, err) + } + out = append(out, correctionInputs...) + } + + return out, nil +} + +func (c *accrualCorrector) collectedSourceBySortHint(group ledger.TransactionGroup, sortHint int) (collectedSource, error) { + sources, err := c.collectedSourcesForGroup(group) + if err != nil { + return collectedSource{}, fmt.Errorf("map correction sources for group %s: %w", group.ID().ID, err) + } + + if sortHint < 0 || sortHint >= len(sources) { + return collectedSource{}, fmt.Errorf("allocation sort hint %d out of range for transaction group %s", sortHint, group.ID().ID) + } + + return sources[sortHint], nil +} + +func (c *accrualCorrector) collectedSourcesForGroup(group ledger.TransactionGroup) ([]collectedSource, error) { + out := make([]collectedSource, 0) + for _, transaction := range group.Transactions() { + templateName, err := ledger.TransactionTemplateNameFromAnnotations(transaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction %s template name: %w", transaction.ID().ID, err) + } + + direction, err := ledger.TransactionDirectionFromAnnotations(transaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction %s direction: %w", transaction.ID().ID, err) + } + if direction != ledger.TransactionDirectionForward { + continue + } + + // Advance-backed collection comes with a receivable issue in the same group. + var advanceReceivableIssueTransaction ledger.Transaction + if templateName == transactions.TemplateName(transactions.TransferCustomerFBOAdvanceToAccruedTemplate{}) { + advanceReceivableIssueTransaction, err = c.forwardTransactionByTemplate(group, transactions.TemplateName(transactions.IssueCustomerReceivableTemplate{})) + if err != nil { + return nil, fmt.Errorf("find issue receivable companion in group %s: %w", group.ID().ID, err) + } + } + + for _, entry := range transaction.Entries() { + if entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO && entry.Amount().IsNegative() { + out = append(out, collectedSource{ + transaction: transaction, + group: group, + advanceReceivableIssueTransaction: advanceReceivableIssueTransaction, + }) + } + } + } + + return out, nil +} + +func (c *accrualCorrector) forwardTransactionByTemplate(group ledger.TransactionGroup, templateName string) (ledger.Transaction, error) { + for _, transaction := range group.Transactions() { + currentTemplateName, err := ledger.TransactionTemplateNameFromAnnotations(transaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction %s template name: %w", transaction.ID().ID, err) + } + + direction, err := ledger.TransactionDirectionFromAnnotations(transaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction %s direction: %w", transaction.ID().ID, err) + } + + if currentTemplateName == templateName && direction == ledger.TransactionDirectionForward { + return transaction, nil + } + } + + return nil, fmt.Errorf("transaction with template %s not found", templateName) +} + +func (c *accrualCorrector) backfilledIssueRoute(group ledger.TransactionGroup) (currencyx.Code, *alpacadecimal.Decimal, error) { + for _, transaction := range group.Transactions() { + for _, entry := range transaction.Entries() { + route := entry.PostingAddress().Route().Route() + if route.CostBasis != nil { + return route.Currency, route.CostBasis, nil + } + } + } + + return "", nil, fmt.Errorf("backing transaction group %s does not contain a known cost basis route", group.ID().ID) +} + +func sortCorrectionSegments(segments []lineage.Segment) []lineage.Segment { + sorted := append([]lineage.Segment(nil), segments...) + sort.SliceStable(sorted, func(i, j int) bool { + // Go from most downstream representation back outward. + precedence := func(state creditrealization.LineageSegmentState) int { + switch state { + case creditrealization.LineageSegmentStateAdvanceBackfilled: + return 0 + case creditrealization.LineageSegmentStateAdvanceUncovered: + return 1 + case creditrealization.LineageSegmentStateRealCredit: + return 2 + default: + return 3 + } + } + + return precedence(sorted[i].State) < precedence(sorted[j].State) + }) + + return sorted +} + +func minDecimal(a, b alpacadecimal.Decimal) alpacadecimal.Decimal { + if a.GreaterThan(b) { + return b + } + + return a +} diff --git a/openmeter/ledger/collector/service.go b/openmeter/ledger/collector/service.go new file mode 100644 index 0000000000..3864e35278 --- /dev/null +++ b/openmeter/ledger/collector/service.go @@ -0,0 +1,72 @@ +package collector + +import ( + "context" + "time" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +type Service interface { + CollectToAccrued(ctx context.Context, input CollectToAccruedInput) (creditrealization.CreateAllocationInputs, error) + CorrectCollectedAccrued(ctx context.Context, input CorrectCollectedAccruedInput) (creditrealization.CreateCorrectionInputs, error) +} + +type Config struct { + Ledger ledger.Ledger + Dependencies transactions.ResolverDependencies +} + +type CollectToAccruedInput struct { + Namespace string + ChargeID string + CustomerID string + At time.Time + Currency currencyx.Code + SettlementMode productcatalog.SettlementMode + ServicePeriod timeutil.ClosedPeriod + Amount alpacadecimal.Decimal +} + +type CorrectCollectedAccruedInput struct { + Namespace string + ChargeID string + CustomerID string + AllocateAt time.Time + Corrections creditrealization.CorrectionRequest + LineageSegmentsByRealization lineage.ActiveSegmentsByRealizationID +} + +type service struct { + collector *accrualCollector + corrector *accrualCorrector +} + +func NewService(config Config) Service { + return &service{ + collector: &accrualCollector{ + ledger: config.Ledger, + deps: config.Dependencies, + }, + corrector: &accrualCorrector{ + ledger: config.Ledger, + deps: config.Dependencies, + }, + } +} + +func (s *service) CollectToAccrued(ctx context.Context, input CollectToAccruedInput) (creditrealization.CreateAllocationInputs, error) { + return s.collector.collect(ctx, input) +} + +func (s *service) CorrectCollectedAccrued(ctx context.Context, input CorrectCollectedAccruedInput) (creditrealization.CreateCorrectionInputs, error) { + return s.corrector.correct(ctx, input) +} diff --git a/openmeter/ledger/customerbalance/testenv_test.go b/openmeter/ledger/customerbalance/testenv_test.go index aa7cc40b10..a249bb79a4 100644 --- a/openmeter/ledger/customerbalance/testenv_test.go +++ b/openmeter/ledger/customerbalance/testenv_test.go @@ -17,6 +17,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + lineageadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/adapter" + lineageservice "github.com/openmeterio/openmeter/openmeter/billing/charges/lineage/service" chargemeta "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" chargestestutils "github.com/openmeterio/openmeter/openmeter/billing/charges/testutils" @@ -148,6 +150,16 @@ func newTestEnv(t *testing.T) *testEnv { }) require.NoError(t, err) + lineageAdapter, err := lineageadapter.New(lineageadapter.Config{ + Client: base.DB, + }) + require.NoError(t, err) + + lineageService, err := lineageservice.New(lineageservice.Config{ + Adapter: lineageAdapter, + }) + require.NoError(t, err) + usageAdapter, err := usagebasedadapter.New(usagebasedadapter.Config{ Client: base.DB, Logger: logger, @@ -165,6 +177,7 @@ func newTestEnv(t *testing.T) *testEnv { flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ Adapter: flatFeeAdapter, Handler: handlers.FlatFee, + Lineage: lineageService, MetaAdapter: metaAdapter, Locker: locker, }) @@ -173,6 +186,7 @@ func newTestEnv(t *testing.T) *testEnv { usageService, err := usagebasedservice.New(usagebasedservice.Config{ Adapter: usageAdapter, Handler: handlers.UsageBased, + Lineage: lineageService, Locker: locker, MetaAdapter: metaAdapter, CustomerOverrideService: billingService, diff --git a/openmeter/ledger/historical/adapter/ledger.go b/openmeter/ledger/historical/adapter/ledger.go index 9fd468be10..892896598a 100644 --- a/openmeter/ledger/historical/adapter/ledger.go +++ b/openmeter/ledger/historical/adapter/ledger.go @@ -11,6 +11,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db" ledgerentrydb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgerentry" ledgertransactiondb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransaction" + ledgertransactiongroupdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransactiongroup" "github.com/openmeterio/openmeter/openmeter/ledger" ledgerhistorical "github.com/openmeterio/openmeter/openmeter/ledger/historical" "github.com/openmeterio/openmeter/pkg/currencyx" @@ -19,6 +20,65 @@ import ( "github.com/openmeterio/openmeter/pkg/slicesx" ) +func hydrateHistoricalTransaction(tx *db.LedgerTransaction) (*ledgerhistorical.Transaction, error) { + entryData, err := slicesx.MapWithErr(tx.Edges.Entries, func(entry *db.LedgerEntry) (ledgerhistorical.EntryData, error) { + subAccount, err := entry.Edges.SubAccountOrErr() + if err != nil { + return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s missing sub-account edge: %w", entry.ID, err) + } + + account, err := subAccount.Edges.AccountOrErr() + if err != nil { + return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s sub-account %s missing account edge: %w", entry.ID, subAccount.ID, err) + } + route, err := subAccount.Edges.RouteOrErr() + if err != nil { + return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s sub-account %s missing route edge: %w", entry.ID, subAccount.ID, err) + } + + return ledgerhistorical.EntryData{ + ID: entry.ID, + Namespace: entry.Namespace, + Annotations: entry.Annotations, + CreatedAt: entry.CreatedAt, + SubAccountID: entry.SubAccountID, + AccountType: account.AccountType, + Route: ledger.Route{ + Currency: currencyx.Code(route.Currency), + TaxCode: route.TaxCode, + Features: route.Features, + CostBasis: route.CostBasis, + CreditPriority: route.CreditPriority, + TransactionAuthorizationStatus: route.TransactionAuthorizationStatus, + }, + RouteID: route.ID, + RouteKey: route.RoutingKey, + RouteKeyVer: route.RoutingKeyVersion, + Amount: entry.Amount, + TransactionID: entry.TransactionID, + }, nil + }) + if err != nil { + return nil, fmt.Errorf("transaction %s entry hydration failed: %w", tx.ID, err) + } + + reconstructed, err := ledgerhistorical.NewTransactionFromData( + ledgerhistorical.TransactionData{ + ID: tx.ID, + Namespace: tx.Namespace, + Annotations: tx.Annotations, + CreatedAt: tx.CreatedAt, + BookedAt: tx.BookedAt, + }, + entryData, + ) + if err != nil { + return nil, fmt.Errorf("transaction %s: %w", tx.ID, err) + } + + return reconstructed, nil +} + func (r *repo) BookTransaction(ctx context.Context, groupID models.NamespacedID, input ledger.TransactionInput) (*ledgerhistorical.Transaction, error) { if input == nil { return nil, ledger.ErrTransactionInputRequired @@ -27,6 +87,7 @@ func (r *repo) BookTransaction(ctx context.Context, groupID models.NamespacedID, entity, err := r.db.LedgerTransaction.Create(). SetNamespace(groupID.Namespace). SetGroupID(groupID.ID). + SetAnnotations(input.Annotations()). SetBookedAt(input.BookedAt()). Save(ctx) if err != nil { @@ -113,6 +174,49 @@ func (r *repo) CreateTransactionGroup(ctx context.Context, transactionGroup ledg }, nil } +func (r *repo) GetTransactionGroup(ctx context.Context, id models.NamespacedID) (*ledgerhistorical.TransactionGroup, error) { + entity, err := r.db.LedgerTransactionGroup.Query(). + Where( + ledgertransactiongroupdb.Namespace(id.Namespace), + ledgertransactiongroupdb.ID(id.ID), + ). + WithTransactions(func(q *db.LedgerTransactionQuery) { + q.Order( + ledgertransactiondb.ByCreatedAt(), + ledgertransactiondb.ByID(), + ) + q.WithEntries(func(eq *db.LedgerEntryQuery) { + eq.Order( + ledgerentrydb.ByCreatedAt(), + ledgerentrydb.ByID(), + ) + eq.WithSubAccount(func(sq *db.LedgerSubAccountQuery) { + sq.WithAccount() + sq.WithRoute() + }) + }) + }). + Only(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query transaction group: %w", err) + } + + transactions, err := slicesx.MapWithErr(entity.Edges.Transactions, hydrateHistoricalTransaction) + if err != nil { + return nil, fmt.Errorf("failed to hydrate transaction group transactions: %w", err) + } + + return ledgerhistorical.NewTransactionGroupFromData( + ledgerhistorical.TransactionGroupData{ + ID: entity.ID, + Namespace: entity.Namespace, + CreatedAt: entity.CreatedAt, + Annotations: entity.Annotations, + }, + transactions, + ), nil +} + func (r *repo) SumEntries(ctx context.Context, query ledger.Query) (alpacadecimal.Decimal, error) { q := sumEntriesQuery{ query: query, @@ -179,62 +283,7 @@ func (r *repo) ListTransactions(ctx context.Context, input ledger.ListTransactio } items, err := slicesx.MapWithErr(paged.Items, func(tx *db.LedgerTransaction) (*ledgerhistorical.Transaction, error) { - entryData, err := slicesx.MapWithErr(tx.Edges.Entries, func(entry *db.LedgerEntry) (ledgerhistorical.EntryData, error) { - subAccount, err := entry.Edges.SubAccountOrErr() - if err != nil { - return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s missing sub-account edge: %w", entry.ID, err) - } - - account, err := subAccount.Edges.AccountOrErr() - if err != nil { - return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s sub-account %s missing account edge: %w", entry.ID, subAccount.ID, err) - } - route, err := subAccount.Edges.RouteOrErr() - if err != nil { - return ledgerhistorical.EntryData{}, fmt.Errorf("entry %s sub-account %s missing route edge: %w", entry.ID, subAccount.ID, err) - } - - return ledgerhistorical.EntryData{ - ID: entry.ID, - Namespace: entry.Namespace, - Annotations: entry.Annotations, - CreatedAt: entry.CreatedAt, - SubAccountID: entry.SubAccountID, - AccountType: account.AccountType, - Route: ledger.Route{ - Currency: currencyx.Code(route.Currency), - TaxCode: route.TaxCode, - Features: route.Features, - CostBasis: route.CostBasis, - CreditPriority: route.CreditPriority, - TransactionAuthorizationStatus: route.TransactionAuthorizationStatus, - }, - RouteID: route.ID, - RouteKey: route.RoutingKey, - RouteKeyVer: route.RoutingKeyVersion, - Amount: entry.Amount, - TransactionID: entry.TransactionID, - }, nil - }) - if err != nil { - return nil, fmt.Errorf("transaction %s entry hydration failed: %w", tx.ID, err) - } - - reconstructed, err := ledgerhistorical.NewTransactionFromData( - ledgerhistorical.TransactionData{ - ID: tx.ID, - Namespace: tx.Namespace, - Annotations: tx.Annotations, - CreatedAt: tx.CreatedAt, - BookedAt: tx.BookedAt, - }, - entryData, - ) - if err != nil { - return nil, fmt.Errorf("transaction %s: %w", tx.ID, err) - } - - return reconstructed, nil + return hydrateHistoricalTransaction(tx) }) if err != nil { return pagination.Result[*ledgerhistorical.Transaction]{}, fmt.Errorf("failed to hydrate listed transactions: %w", err) diff --git a/openmeter/ledger/historical/ledger.go b/openmeter/ledger/historical/ledger.go index 0b048c4fc0..2a2f0a0b34 100644 --- a/openmeter/ledger/historical/ledger.go +++ b/openmeter/ledger/historical/ledger.go @@ -57,6 +57,15 @@ func (l *Ledger) ListTransactions(ctx context.Context, params ledger.ListTransac }, nil } +func (l *Ledger) GetTransactionGroup(ctx context.Context, id models.NamespacedID) (ledger.TransactionGroup, error) { + group, err := l.repo.GetTransactionGroup(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get transaction group: %w", err) + } + + return group, nil +} + func (l *Ledger) CommitGroup(ctx context.Context, group ledger.TransactionGroupInput) (ledger.TransactionGroup, error) { txInputs := group.Transactions() diff --git a/openmeter/ledger/historical/repo.go b/openmeter/ledger/historical/repo.go index 321c27a203..12b15732a0 100644 --- a/openmeter/ledger/historical/repo.go +++ b/openmeter/ledger/historical/repo.go @@ -18,6 +18,9 @@ type Repo interface { // Create a transaction group CreateTransactionGroup(ctx context.Context, transactionGroup CreateTransactionGroupInput) (TransactionGroupData, error) + // Get a transaction group with hydrated transactions and entries. + GetTransactionGroup(ctx context.Context, id models.NamespacedID) (*TransactionGroup, error) + // Book a transaction BookTransaction(ctx context.Context, groupID models.NamespacedID, transaction ledger.TransactionInput) (*Transaction, error) diff --git a/openmeter/ledger/historical/transaction.go b/openmeter/ledger/historical/transaction.go index 48d1b09287..c10d566259 100644 --- a/openmeter/ledger/historical/transaction.go +++ b/openmeter/ledger/historical/transaction.go @@ -56,11 +56,22 @@ func (t *Transaction) BookedAt() time.Time { return t.data.BookedAt } +func (t *Transaction) Annotations() models.Annotations { + return t.data.Annotations +} + type TransactionGroup struct { data TransactionGroupData transactions []*Transaction } +func NewTransactionGroupFromData(data TransactionGroupData, transactions []*Transaction) *TransactionGroup { + return &TransactionGroup{ + data: data, + transactions: transactions, + } +} + var _ ledger.TransactionGroup = (*TransactionGroup)(nil) func (t *TransactionGroup) ID() models.NamespacedID { diff --git a/openmeter/ledger/noop/noop.go b/openmeter/ledger/noop/noop.go index 0bb0581476..21442249d3 100644 --- a/openmeter/ledger/noop/noop.go +++ b/openmeter/ledger/noop/noop.go @@ -75,6 +75,10 @@ func (Ledger) CommitGroup(context.Context, ledger.TransactionGroupInput) (ledger return nil, nil } +func (Ledger) GetTransactionGroup(context.Context, models.NamespacedID) (ledger.TransactionGroup, error) { + return nil, nil +} + func (Ledger) ListTransactions(context.Context, ledger.ListTransactionsInput) (pagination.Result[ledger.Transaction], error) { return pagination.Result[ledger.Transaction]{}, nil } diff --git a/openmeter/ledger/primitives.go b/openmeter/ledger/primitives.go index de62aecbf4..12c952354f 100644 --- a/openmeter/ledger/primitives.go +++ b/openmeter/ledger/primitives.go @@ -85,6 +85,7 @@ type Entry interface { type TransactionInput interface { BookedAt() time.Time EntryInputs() []EntryInput + Annotations() models.Annotations AsGroupInput(namespace string, annotations models.Annotations) TransactionGroupInput } @@ -93,6 +94,7 @@ type Transaction interface { BookedAt() time.Time Entries() []Entry ID() models.NamespacedID + Annotations() models.Annotations } type TransactionGroupInput interface { @@ -116,6 +118,9 @@ type Ledger interface { // CommitGroup commits a list of transactions on the Ledger atomically CommitGroup(ctx context.Context, group TransactionGroupInput) (TransactionGroup, error) + // GetTransactionGroup loads a previously committed transaction group including its transactions. + GetTransactionGroup(ctx context.Context, id models.NamespacedID) (TransactionGroup, error) + // ListTransactions lists transactions on the Ledger according to some filters // // TODO: Cursoring gets problematic due to diff between wall_clock and booked_at. It would be convenient to return in order of booked_at as that simplifies parsing. This API will likely change. diff --git a/openmeter/ledger/routingrules/routingrules.go b/openmeter/ledger/routingrules/routingrules.go index b765833c60..d8ed8f5215 100644 --- a/openmeter/ledger/routingrules/routingrules.go +++ b/openmeter/ledger/routingrules/routingrules.go @@ -78,10 +78,13 @@ func (r RequireFlowDirectionRule) Validate(tx TxView) error { for _, entry := range fromEntries { if !entry.Amount().IsNegative() { + if allEntriesPositive(fromEntries) && allEntriesNegative(toEntries) { + return nil + } return ledger.ErrRoutingRuleViolated.WithAttrs(models.Attributes{ "reason": "invalid_flow_direction", "account_type": r.From, - "expected": "negative", + "expected": "negative_or_positive_if_reversed", "target_type": r.To, }) } @@ -89,10 +92,13 @@ func (r RequireFlowDirectionRule) Validate(tx TxView) error { for _, entry := range toEntries { if !entry.Amount().IsPositive() { + if allEntriesPositive(fromEntries) && allEntriesNegative(toEntries) { + return nil + } return ledger.ErrRoutingRuleViolated.WithAttrs(models.Attributes{ "reason": "invalid_flow_direction", "account_type": r.To, - "expected": "positive", + "expected": "positive_or_negative_if_reversed", "source_type": r.From, }) } @@ -116,6 +122,26 @@ func hasMixedSigns(entries []EntryView) bool { return false } +func allEntriesPositive(entries []EntryView) bool { + for _, entry := range entries { + if !entry.Amount().IsPositive() { + return false + } + } + + return len(entries) > 0 +} + +func allEntriesNegative(entries []EntryView) bool { + for _, entry := range entries { + if !entry.Amount().IsNegative() { + return false + } + } + + return len(entries) > 0 +} + type RouteField string const ( @@ -210,7 +236,7 @@ func (r RequireReceivableAuthorizationStageRule) Validate(tx TxView) error { if allEntriesHaveAuthorizationStatus(negativeEntries, ledger.TransactionAuthorizationStatusOpen) && allEntriesHaveAuthorizationStatus(positiveEntries, ledger.TransactionAuthorizationStatusOpen) { - if err := requireKnownToUnknownCostBasisTranslation( + if err := requireKnownToUnknownCostBasisTranslationEitherDirection( negativeEntries, positiveEntries, ledger.AccountTypeCustomerReceivable, @@ -250,9 +276,9 @@ func (r RequireAccruedCostBasisTranslationRule) Validate(tx TxView) error { }) } - return requireKnownToUnknownCostBasisTranslation( - positiveEntries, + return requireKnownToUnknownCostBasisTranslationEitherDirection( negativeEntries, + positiveEntries, ledger.AccountTypeCustomerAccrued, []RouteField{ RouteFieldCurrency, @@ -351,6 +377,14 @@ func requireKnownToUnknownCostBasisTranslation(knownEntries, unknownEntries []En return requireMatchingRouteFields(knownEntries, unknownEntries, accountType, accountType, fields) } +func requireKnownToUnknownCostBasisTranslationEitherDirection(leftEntries, rightEntries []EntryView, accountType ledger.AccountType, fields []RouteField) error { + if err := requireKnownToUnknownCostBasisTranslation(leftEntries, rightEntries, accountType, fields); err == nil { + return nil + } + + return requireKnownToUnknownCostBasisTranslation(rightEntries, leftEntries, accountType, fields) +} + func requireMatchingRouteFields(leftEntries, rightEntries []EntryView, leftType, rightType ledger.AccountType, fields []RouteField) error { for _, left := range leftEntries { matched, err := hasMatchingRouteFields(left, rightEntries, fields) diff --git a/openmeter/ledger/routingrules/routingrules_test.go b/openmeter/ledger/routingrules/routingrules_test.go index bf7fc26e43..9887d889fd 100644 --- a/openmeter/ledger/routingrules/routingrules_test.go +++ b/openmeter/ledger/routingrules/routingrules_test.go @@ -35,6 +35,50 @@ func TestDefaultValidator_AllowsFBOToAccrued(t *testing.T) { require.NoError(t, err) } +func TestDefaultValidator_AllowsAccruedToFBO(t *testing.T) { + validator := routingrules.DefaultValidator + + err := validator.ValidateEntries([]ledger.EntryInput{ + &transactionstestutils.AnyEntryInput{ + Address: addressForRoute(t, ledger.AccountTypeCustomerAccrued, "sub-accrued", ledger.Route{ + Currency: currencyx.Code("USD"), + }), + AmountValue: alpacadecimal.NewFromInt(-50), + }, + &transactionstestutils.AnyEntryInput{ + Address: addressForRoute(t, ledger.AccountTypeCustomerFBO, "sub-fbo", ledger.Route{ + Currency: currencyx.Code("USD"), + }), + AmountValue: alpacadecimal.NewFromInt(50), + }, + }) + + require.NoError(t, err) +} + +func TestDefaultValidator_AllowsFBOToReceivableReverse(t *testing.T) { + validator := routingrules.DefaultValidator + openStatus := ledger.TransactionAuthorizationStatusOpen + + err := validator.ValidateEntries([]ledger.EntryInput{ + &transactionstestutils.AnyEntryInput{ + Address: addressForRoute(t, ledger.AccountTypeCustomerFBO, "sub-fbo", ledger.Route{ + Currency: currencyx.Code("USD"), + }), + AmountValue: alpacadecimal.NewFromInt(-50), + }, + &transactionstestutils.AnyEntryInput{ + Address: addressForRoute(t, ledger.AccountTypeCustomerReceivable, "sub-rec-open", ledger.Route{ + Currency: currencyx.Code("USD"), + TransactionAuthorizationStatus: &openStatus, + }), + AmountValue: alpacadecimal.NewFromInt(50), + }, + }) + + require.NoError(t, err) +} + func TestDefaultValidator_RejectsForbiddenAccountCombination(t *testing.T) { validator := routingrules.DefaultValidator diff --git a/openmeter/ledger/transactions/accrual.go b/openmeter/ledger/transactions/accrual.go index 8f5dc02663..b8922fd088 100644 --- a/openmeter/ledger/transactions/accrual.go +++ b/openmeter/ledger/transactions/accrual.go @@ -42,6 +42,87 @@ func (t TransferCustomerFBOToAccruedTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (TransferCustomerFBOToAccruedTemplate{}) +func (t TransferCustomerFBOToAccruedTemplate) correct(scope CorrectionInput) ([]ledger.TransactionInput, error) { + type selectedDebit struct { + fboAddress ledger.PostingAddress + accruedAddress ledger.PostingAddress + amount alpacadecimal.Decimal + } + + negativeFBOEntries := make([]ledger.Entry, 0) + accruedAddressByCostBasis := make(map[string]ledger.PostingAddress) + totalAvailable := alpacadecimal.Zero + + for _, entry := range scope.OriginalTransaction.Entries() { + switch { + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO && entry.Amount().IsNegative(): + negativeFBOEntries = append(negativeFBOEntries, entry) + totalAvailable = totalAvailable.Add(entry.Amount().Abs()) + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerAccrued && entry.Amount().IsPositive(): + accruedAddressByCostBasis[costBasisKey(entry.PostingAddress().Route().Route().CostBasis)] = entry.PostingAddress() + } + } + + if scope.Amount.GreaterThan(totalAvailable) { + return nil, fmt.Errorf("accrual correction amount %s exceeds original collected amount %s", scope.Amount.String(), totalAvailable.String()) + } + + selected := make([]selectedDebit, 0, len(negativeFBOEntries)) + remaining := scope.Amount + for idx := len(negativeFBOEntries) - 1; idx >= 0 && remaining.IsPositive(); idx-- { + entry := negativeFBOEntries[idx] + amount := entry.Amount().Abs() + if amount.GreaterThan(remaining) { + amount = remaining + } + + accruedAddress, ok := accruedAddressByCostBasis[costBasisKey(entry.PostingAddress().Route().Route().CostBasis)] + if !ok { + return nil, fmt.Errorf("missing accrued entry for FBO cost basis %s", costBasisKey(entry.PostingAddress().Route().Route().CostBasis)) + } + + selected = append(selected, selectedDebit{ + fboAddress: entry.PostingAddress(), + accruedAddress: accruedAddress, + amount: amount, + }) + remaining = remaining.Sub(amount) + } + + if remaining.IsPositive() { + return nil, fmt.Errorf("accrual correction amount %s could not be fully allocated", scope.Amount.String()) + } + + accruedAmountsByAddress := make(map[string]selectedDebit) + entryInputs := make([]*EntryInput, 0, len(selected)*2) + for _, item := range selected { + entryInputs = append(entryInputs, &EntryInput{ + address: item.fboAddress, + amount: item.amount, + }) + + key := item.accruedAddress.SubAccountID() + current := accruedAmountsByAddress[key] + current.accruedAddress = item.accruedAddress + current.amount = current.amount.Add(item.amount) + accruedAmountsByAddress[key] = current + } + + for _, item := range accruedAmountsByAddress { + entryInputs = append(entryInputs, &EntryInput{ + address: item.accruedAddress, + amount: item.amount.Neg(), + }) + } + + return []ledger.TransactionInput{ + &TransactionInput{ + bookedAt: scope.At, + entryInputs: entryInputs, + }, + }, nil +} + func (t TransferCustomerFBOToAccruedTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { collections, err := collectFromPrioritizedCustomerFBO(ctx, customerID, t.Currency, t.Amount, resolvers) if err != nil { @@ -135,9 +216,9 @@ func costBasisKey(costBasis *alpacadecimal.Decimal) string { return costBasis.String() } -// TransferCustomerFBOBucketToAccruedTemplate moves value from a specific customer FBO route -// into the matching accrued route. This is used for explicit advance-backed collection flows. -type TransferCustomerFBOBucketToAccruedTemplate struct { +// TransferCustomerFBOAdvanceToAccruedTemplate moves value from the synthetic advance-backed +// customer FBO route into the matching accrued route. +type TransferCustomerFBOAdvanceToAccruedTemplate struct { At time.Time Amount alpacadecimal.Decimal Currency currencyx.Code @@ -145,7 +226,7 @@ type TransferCustomerFBOBucketToAccruedTemplate struct { CreditPriority *int } -func (t TransferCustomerFBOBucketToAccruedTemplate) Validate() error { +func (t TransferCustomerFBOAdvanceToAccruedTemplate) Validate() error { if t.At.IsZero() { return fmt.Errorf("at is required") } @@ -173,13 +254,55 @@ func (t TransferCustomerFBOBucketToAccruedTemplate) Validate() error { return nil } -func (t TransferCustomerFBOBucketToAccruedTemplate) typeGuard() guard { +func (t TransferCustomerFBOAdvanceToAccruedTemplate) typeGuard() guard { return true } -var _ CustomerTransactionTemplate = (TransferCustomerFBOBucketToAccruedTemplate{}) +var _ CustomerTransactionTemplate = (TransferCustomerFBOAdvanceToAccruedTemplate{}) + +func (t TransferCustomerFBOAdvanceToAccruedTemplate) correct(scope CorrectionInput) ([]ledger.TransactionInput, error) { + var fboAddress ledger.PostingAddress + var accruedAddress ledger.PostingAddress + var fboAmount alpacadecimal.Decimal + var accruedAmount alpacadecimal.Decimal + + for _, entry := range scope.OriginalTransaction.Entries() { + switch { + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO && entry.Amount().IsNegative(): + fboAddress = entry.PostingAddress() + fboAmount = fboAmount.Add(entry.Amount().Abs()) + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerAccrued && entry.Amount().IsPositive(): + accruedAddress = entry.PostingAddress() + accruedAmount = accruedAmount.Add(entry.Amount()) + } + } + + if fboAddress == nil || accruedAddress == nil { + return nil, fmt.Errorf("bucket accrual correction requires original FBO and accrued entries") + } + + if scope.Amount.GreaterThan(fboAmount) || scope.Amount.GreaterThan(accruedAmount) { + return nil, fmt.Errorf("bucket accrual correction amount %s exceeds original transaction amount", scope.Amount.String()) + } + + return []ledger.TransactionInput{ + &TransactionInput{ + bookedAt: scope.At, + entryInputs: []*EntryInput{ + { + address: fboAddress, + amount: scope.Amount, + }, + { + address: accruedAddress, + amount: scope.Amount.Neg(), + }, + }, + }, + }, nil +} -func (t TransferCustomerFBOBucketToAccruedTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { +func (t TransferCustomerFBOAdvanceToAccruedTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { priority := resolveCustomerFBOCreditPriority(t.CreditPriority) customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) @@ -258,6 +381,10 @@ func (t TransferCustomerReceivableToAccruedTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (TransferCustomerReceivableToAccruedTemplate{}) +func (t TransferCustomerReceivableToAccruedTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t TransferCustomerReceivableToAccruedTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) if err != nil { @@ -346,6 +473,50 @@ func (t TranslateCustomerAccruedCostBasisTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (TranslateCustomerAccruedCostBasisTemplate{}) +func (t TranslateCustomerAccruedCostBasisTemplate) correct(scope CorrectionInput) ([]ledger.TransactionInput, error) { + var fromAccruedAddress ledger.PostingAddress + var toAccruedAddress ledger.PostingAddress + var fromAccruedAmount alpacadecimal.Decimal + var toAccruedAmount alpacadecimal.Decimal + + for _, entry := range scope.OriginalTransaction.Entries() { + switch { + case entry.PostingAddress().AccountType() != ledger.AccountTypeCustomerAccrued: + continue + case entry.Amount().IsNegative(): + fromAccruedAddress = entry.PostingAddress() + fromAccruedAmount = fromAccruedAmount.Add(entry.Amount().Abs()) + case entry.Amount().IsPositive(): + toAccruedAddress = entry.PostingAddress() + toAccruedAmount = toAccruedAmount.Add(entry.Amount()) + } + } + + if fromAccruedAddress == nil || toAccruedAddress == nil { + return nil, fmt.Errorf("accrued cost-basis translation correction requires original accrued entries") + } + + if scope.Amount.GreaterThan(fromAccruedAmount) || scope.Amount.GreaterThan(toAccruedAmount) { + return nil, fmt.Errorf("accrued cost-basis translation correction amount %s exceeds original transaction amount", scope.Amount.String()) + } + + return []ledger.TransactionInput{ + &TransactionInput{ + bookedAt: scope.At, + entryInputs: []*EntryInput{ + { + address: fromAccruedAddress, + amount: scope.Amount, + }, + { + address: toAccruedAddress, + amount: scope.Amount.Neg(), + }, + }, + }, + }, nil +} + func (t TranslateCustomerAccruedCostBasisTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) if err != nil { diff --git a/openmeter/ledger/transactions/accrual_test.go b/openmeter/ledger/transactions/accrual_test.go index f3746132aa..72823c0aa2 100644 --- a/openmeter/ledger/transactions/accrual_test.go +++ b/openmeter/ledger/transactions/accrual_test.go @@ -75,7 +75,7 @@ func TestTransferCustomerReceivableToAccruedTemplate(t *testing.T) { require.True(t, env.SumBalance(t, env.AccruedSubAccountWithCostBasis(t, &costBasis)).Equal(alpacadecimal.NewFromInt(50))) } -func TestTransferCustomerFBOBucketToAccruedTemplate_UnknownCostBasisAdvanceNetEffect(t *testing.T) { +func TestTransferCustomerFBOAdvanceToAccruedTemplate_UnknownCostBasisAdvanceNetEffect(t *testing.T) { env := newTransactionsTestEnv(t) inputs := env.resolveAndCommit( @@ -85,7 +85,7 @@ func TestTransferCustomerFBOBucketToAccruedTemplate_UnknownCostBasisAdvanceNetEf Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, }, - TransferCustomerFBOBucketToAccruedTemplate{ + TransferCustomerFBOAdvanceToAccruedTemplate{ At: env.Now(), Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, @@ -109,7 +109,7 @@ func TestTranslateCustomerAccruedCostBasisTemplate(t *testing.T) { Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, }, - TransferCustomerFBOBucketToAccruedTemplate{ + TransferCustomerFBOAdvanceToAccruedTemplate{ At: env.Now(), Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, diff --git a/openmeter/ledger/transactions/correction.go b/openmeter/ledger/transactions/correction.go new file mode 100644 index 0000000000..3a427e4287 --- /dev/null +++ b/openmeter/ledger/transactions/correction.go @@ -0,0 +1,122 @@ +package transactions + +import ( + "context" + "fmt" + "time" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/ledger" +) + +type CorrectionInput struct { + At time.Time + Amount alpacadecimal.Decimal + + OriginalTransaction ledger.Transaction + OriginalGroup ledger.TransactionGroup +} + +type CorrectionScope = CorrectionInput + +func (i CorrectionScope) Validate() error { + if i.At.IsZero() { + return fmt.Errorf("at is required") + } + + if err := ledger.ValidateTransactionAmount(i.Amount); err != nil { + return fmt.Errorf("amount: %w", err) + } + + if i.OriginalTransaction == nil { + return fmt.Errorf("original transaction is required") + } + + return nil +} + +func CorrectTransaction( + _ context.Context, + deps ResolverDependencies, + scope CorrectionScope, +) ([]ledger.TransactionInput, error) { + if err := scope.Validate(); err != nil { + return nil, fmt.Errorf("validate correction input: %w", err) + } + + direction, err := ledger.TransactionDirectionFromAnnotations(scope.OriginalTransaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction direction: %w", err) + } + + if direction == ledger.TransactionDirectionCorrection { + return nil, fmt.Errorf("cannot correct a correction transaction") + } + + templateName, err := ledger.TransactionTemplateNameFromAnnotations(scope.OriginalTransaction.Annotations()) + if err != nil { + return nil, fmt.Errorf("transaction template name: %w", err) + } + + template, err := transactionTemplateByName(templateName) + if err != nil { + return nil, err + } + + outputs, err := correctTemplate(scope, template) + if err != nil { + return nil, err + } + + annotated := make([]ledger.TransactionInput, 0, len(outputs)) + for _, output := range outputs { + annotated = append(annotated, annotateTemplateTransaction(output, template, ledger.TransactionDirectionCorrection)) + } + + return annotated, nil +} + +func transactionTemplateByName(name string) (TransactionTemplate, error) { + switch name { + case templateName(IssueCustomerReceivableTemplate{}): + return IssueCustomerReceivableTemplate{}, nil + case templateName(FundCustomerReceivableTemplate{}): + return FundCustomerReceivableTemplate{}, nil + case templateName(SettleCustomerReceivablePaymentTemplate{}): + return SettleCustomerReceivablePaymentTemplate{}, nil + case templateName(AttributeCustomerAdvanceReceivableCostBasisTemplate{}): + return AttributeCustomerAdvanceReceivableCostBasisTemplate{}, nil + case templateName(CoverCustomerReceivableTemplate{}): + return CoverCustomerReceivableTemplate{}, nil + case templateName(TransferCustomerFBOToAccruedTemplate{}): + return TransferCustomerFBOToAccruedTemplate{}, nil + case templateName(TransferCustomerFBOAdvanceToAccruedTemplate{}): + return TransferCustomerFBOAdvanceToAccruedTemplate{}, nil + case templateName(TransferCustomerReceivableToAccruedTemplate{}): + return TransferCustomerReceivableToAccruedTemplate{}, nil + case templateName(TranslateCustomerAccruedCostBasisTemplate{}): + return TranslateCustomerAccruedCostBasisTemplate{}, nil + case templateName(RecognizeEarningsFromAttributableAccruedTemplate{}): + return RecognizeEarningsFromAttributableAccruedTemplate{}, nil + case templateName(ConvertCurrencyTemplate{}): + return ConvertCurrencyTemplate{}, nil + default: + return nil, fmt.Errorf("unknown correction template %q", name) + } +} + +func correctTemplate(scope CorrectionScope, template TransactionTemplate) ([]ledger.TransactionInput, error) { + switch typ := any(template).(type) { + case CustomerTransactionTemplate: + return typ.correct(scope) + case OrgTransactionTemplate: + return typ.correct(scope) + default: + return nil, fmt.Errorf("unsupported correction template type %T", template) + } +} + +func templateCorrectionNotImplemented(name string) error { + return fmt.Errorf("%s correction is not implemented", name) +} diff --git a/openmeter/ledger/transactions/correction_test.go b/openmeter/ledger/transactions/correction_test.go new file mode 100644 index 0000000000..8d394fadfb --- /dev/null +++ b/openmeter/ledger/transactions/correction_test.go @@ -0,0 +1,72 @@ +package transactions + +import ( + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/models" +) + +type correctionTestTransaction struct { + id models.NamespacedID + bookedAt time.Time + annotations models.Annotations +} + +var _ ledger.Transaction = (*correctionTestTransaction)(nil) + +func (t *correctionTestTransaction) BookedAt() time.Time { + return t.bookedAt +} + +func (t *correctionTestTransaction) Entries() []ledger.Entry { + return nil +} + +func (t *correctionTestTransaction) ID() models.NamespacedID { + return t.id +} + +func (t *correctionTestTransaction) Annotations() models.Annotations { + return t.annotations +} + +func TestCorrectTransactionRejectsCorrectionDirection(t *testing.T) { + t.Parallel() + + _, err := CorrectTransaction(t.Context(), ResolverDependencies{}, CorrectionInput{ + At: time.Now(), + Amount: alpacadecimal.NewFromInt(1), + OriginalTransaction: &correctionTestTransaction{ + id: models.NamespacedID{Namespace: "ns", ID: "tx"}, + annotations: ledger.TransactionAnnotations( + templateName(TransferCustomerFBOToAccruedTemplate{}), + ledger.TransactionDirectionCorrection, + ), + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot correct a correction transaction") +} + +func TestCorrectTransactionDispatchesTemplateStub(t *testing.T) { + t.Parallel() + + _, err := CorrectTransaction(t.Context(), ResolverDependencies{}, CorrectionInput{ + At: time.Now(), + Amount: alpacadecimal.NewFromInt(1), + OriginalTransaction: &correctionTestTransaction{ + id: models.NamespacedID{Namespace: "ns", ID: "tx"}, + annotations: ledger.TransactionAnnotations( + templateName(FundCustomerReceivableTemplate{}), + ledger.TransactionDirectionForward, + ), + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "FundCustomerReceivableTemplate correction is not implemented") +} diff --git a/openmeter/ledger/transactions/customer.go b/openmeter/ledger/transactions/customer.go index c4e4166543..02df68c2ac 100644 --- a/openmeter/ledger/transactions/customer.go +++ b/openmeter/ledger/transactions/customer.go @@ -60,6 +60,48 @@ func (t IssueCustomerReceivableTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (IssueCustomerReceivableTemplate{}) +func (t IssueCustomerReceivableTemplate) correct(scope CorrectionInput) ([]ledger.TransactionInput, error) { + var fboAddress ledger.PostingAddress + var receivableAddress ledger.PostingAddress + var fboAmount alpacadecimal.Decimal + var receivableAmount alpacadecimal.Decimal + + for _, entry := range scope.OriginalTransaction.Entries() { + switch { + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerFBO && entry.Amount().IsPositive(): + fboAddress = entry.PostingAddress() + fboAmount = fboAmount.Add(entry.Amount()) + case entry.PostingAddress().AccountType() == ledger.AccountTypeCustomerReceivable && entry.Amount().IsNegative(): + receivableAddress = entry.PostingAddress() + receivableAmount = receivableAmount.Add(entry.Amount().Abs()) + } + } + + if fboAddress == nil || receivableAddress == nil { + return nil, fmt.Errorf("issue receivable correction requires original FBO and receivable entries") + } + + if scope.Amount.GreaterThan(fboAmount) || scope.Amount.GreaterThan(receivableAmount) { + return nil, fmt.Errorf("issue receivable correction amount %s exceeds original transaction amount", scope.Amount.String()) + } + + return []ledger.TransactionInput{ + &TransactionInput{ + bookedAt: scope.At, + entryInputs: []*EntryInput{ + { + address: fboAddress, + amount: scope.Amount.Neg(), + }, + { + address: receivableAddress, + amount: scope.Amount, + }, + }, + }, + }, nil +} + func (t IssueCustomerReceivableTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { priority := resolveCustomerFBOCreditPriority(t.CreditPriority) @@ -133,6 +175,10 @@ func (t FundCustomerReceivableTemplate) Validate() error { var _ CustomerTransactionTemplate = (FundCustomerReceivableTemplate{}) +func (t FundCustomerReceivableTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t FundCustomerReceivableTemplate) typeGuard() guard { return true } @@ -216,6 +262,10 @@ func (t SettleCustomerReceivablePaymentTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (SettleCustomerReceivablePaymentTemplate{}) +func (t SettleCustomerReceivablePaymentTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t SettleCustomerReceivablePaymentTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) if err != nil { @@ -294,6 +344,50 @@ func (t AttributeCustomerAdvanceReceivableCostBasisTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (AttributeCustomerAdvanceReceivableCostBasisTemplate{}) +func (t AttributeCustomerAdvanceReceivableCostBasisTemplate) correct(scope CorrectionInput) ([]ledger.TransactionInput, error) { + var advanceReceivableAddress ledger.PostingAddress + var attributedReceivableAddress ledger.PostingAddress + var advanceReceivableAmount alpacadecimal.Decimal + var attributedReceivableAmount alpacadecimal.Decimal + + for _, entry := range scope.OriginalTransaction.Entries() { + switch { + case entry.PostingAddress().AccountType() != ledger.AccountTypeCustomerReceivable: + continue + case entry.Amount().IsPositive(): + advanceReceivableAddress = entry.PostingAddress() + advanceReceivableAmount = advanceReceivableAmount.Add(entry.Amount()) + case entry.Amount().IsNegative(): + attributedReceivableAddress = entry.PostingAddress() + attributedReceivableAmount = attributedReceivableAmount.Add(entry.Amount().Abs()) + } + } + + if advanceReceivableAddress == nil || attributedReceivableAddress == nil { + return nil, fmt.Errorf("advance receivable attribution correction requires original receivable entries") + } + + if scope.Amount.GreaterThan(advanceReceivableAmount) || scope.Amount.GreaterThan(attributedReceivableAmount) { + return nil, fmt.Errorf("advance receivable attribution correction amount %s exceeds original transaction amount", scope.Amount.String()) + } + + return []ledger.TransactionInput{ + &TransactionInput{ + bookedAt: scope.At, + entryInputs: []*EntryInput{ + { + address: advanceReceivableAddress, + amount: scope.Amount.Neg(), + }, + { + address: attributedReceivableAddress, + amount: scope.Amount, + }, + }, + }, + }, nil +} + func (t AttributeCustomerAdvanceReceivableCostBasisTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) if err != nil { @@ -377,6 +471,10 @@ func (t CoverCustomerReceivableTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (CoverCustomerReceivableTemplate{}) +func (t CoverCustomerReceivableTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t CoverCustomerReceivableTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { priority := resolveCustomerFBOCreditPriority(t.CreditPriority) diff --git a/openmeter/ledger/transactions/earnings.go b/openmeter/ledger/transactions/earnings.go index 22a97ff0f1..1df2c26cfc 100644 --- a/openmeter/ledger/transactions/earnings.go +++ b/openmeter/ledger/transactions/earnings.go @@ -42,6 +42,10 @@ func (t RecognizeEarningsFromAttributableAccruedTemplate) typeGuard() guard { var _ CustomerTransactionTemplate = (RecognizeEarningsFromAttributableAccruedTemplate{}) +func (t RecognizeEarningsFromAttributableAccruedTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t RecognizeEarningsFromAttributableAccruedTemplate) resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) { collections, err := collectFromAttributableCustomerAccrued(ctx, customerID, t.Currency, t.Amount, resolvers) if err != nil { diff --git a/openmeter/ledger/transactions/earnings_test.go b/openmeter/ledger/transactions/earnings_test.go index 860e49b1c5..e51dfb5e2a 100644 --- a/openmeter/ledger/transactions/earnings_test.go +++ b/openmeter/ledger/transactions/earnings_test.go @@ -46,7 +46,7 @@ func TestRecognizeEarningsFromAttributableAccruedTemplate_IgnoresUnknownCostBasi Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, }, - TransferCustomerFBOBucketToAccruedTemplate{ + TransferCustomerFBOAdvanceToAccruedTemplate{ At: env.Now(), Amount: alpacadecimal.NewFromInt(30), Currency: env.Currency, diff --git a/openmeter/ledger/transactions/fx.go b/openmeter/ledger/transactions/fx.go index 9b7b53740b..7cd1201174 100644 --- a/openmeter/ledger/transactions/fx.go +++ b/openmeter/ledger/transactions/fx.go @@ -55,6 +55,10 @@ func (t ConvertCurrencyTemplate) Validate() error { var _ CustomerTransactionTemplate = (ConvertCurrencyTemplate{}) +func (t ConvertCurrencyTemplate) correct(CorrectionInput) ([]ledger.TransactionInput, error) { + return nil, templateCorrectionNotImplemented(templateName(t)) +} + func (t ConvertCurrencyTemplate) typeGuard() guard { return true } diff --git a/openmeter/ledger/transactions/input.go b/openmeter/ledger/transactions/input.go index 651fdf03ff..32493aa7d8 100644 --- a/openmeter/ledger/transactions/input.go +++ b/openmeter/ledger/transactions/input.go @@ -32,6 +32,7 @@ func (e *EntryInput) Amount() alpacadecimal.Decimal { type TransactionInput struct { bookedAt time.Time entryInputs []*EntryInput + annotations models.Annotations } // ---------------------------------------------------------------------------- @@ -50,6 +51,10 @@ func (t *TransactionInput) EntryInputs() []ledger.EntryInput { }) } +func (t *TransactionInput) Annotations() models.Annotations { + return t.annotations +} + func (t *TransactionInput) AsGroupInput(namespace string, annotations models.Annotations) ledger.TransactionGroupInput { return &TransactionGroupInput{ namespace: namespace, @@ -66,6 +71,34 @@ func GroupInputs(namespace string, annotations models.Annotations, inputs ...led } } +func WithAnnotations(input ledger.TransactionInput, annotations models.Annotations) ledger.TransactionInput { + merged := make(models.Annotations, len(input.Annotations())+len(annotations)) + + for key, value := range input.Annotations() { + merged[key] = value + } + + for key, value := range annotations { + merged[key] = value + } + + return &annotatedTransactionInput{ + TransactionInput: input, + annotations: merged, + } +} + +type annotatedTransactionInput struct { + ledger.TransactionInput + annotations models.Annotations +} + +var _ ledger.TransactionInput = (*annotatedTransactionInput)(nil) + +func (a *annotatedTransactionInput) Annotations() models.Annotations { + return a.annotations +} + type TransactionGroupInput struct { namespace string transactions []ledger.TransactionInput diff --git a/openmeter/ledger/transactions/resolve.go b/openmeter/ledger/transactions/resolve.go index fc8df0ea39..1879a9f871 100644 --- a/openmeter/ledger/transactions/resolve.go +++ b/openmeter/ledger/transactions/resolve.go @@ -3,6 +3,7 @@ package transactions import ( "context" "fmt" + "reflect" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ledger" @@ -60,20 +61,12 @@ func (s ResolutionScope) validateForOrgTransaction() error { return nil } -type ( - guard bool // private type guard - Resolver interface { - typeGuard() guard - Validate() error - } -) - // ResolveTransactions resolves a list of transaction templates into a list of transaction inputs func ResolveTransactions( ctx context.Context, deps ResolverDependencies, scope ResolutionScope, - templates ...Resolver, + templates ...TransactionTemplate, ) ([]ledger.TransactionInput, error) { if err := scope.Validate(); err != nil { return nil, err @@ -98,7 +91,7 @@ func ResolveTransactions( } if tx != nil { - inputs = append(inputs, tx) + inputs = append(inputs, annotateTemplateTransaction(tx, template, ledger.TransactionDirectionForward)) } case OrgTransactionTemplate: if err := scope.validateForOrgTransaction(); err != nil { @@ -111,7 +104,7 @@ func ResolveTransactions( } if tx != nil { - inputs = append(inputs, tx) + inputs = append(inputs, annotateTemplateTransaction(tx, template, ledger.TransactionDirectionForward)) } default: return nil, ledger.ErrResolutionTemplateUnknown.WithAttrs(models.Attributes{ @@ -122,3 +115,20 @@ func ResolveTransactions( return inputs, nil } + +func templateName(template TransactionTemplate) string { + typ := reflect.TypeOf(template) + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + + return typ.Name() +} + +func TemplateName(template TransactionTemplate) string { + return templateName(template) +} + +func annotateTemplateTransaction(input ledger.TransactionInput, template TransactionTemplate, direction ledger.TransactionDirection) ledger.TransactionInput { + return WithAnnotations(input, ledger.TransactionAnnotations(templateName(template), direction)) +} diff --git a/openmeter/ledger/transactions/resolve_test.go b/openmeter/ledger/transactions/resolve_test.go index 69e3ed89fd..ae162475b2 100644 --- a/openmeter/ledger/transactions/resolve_test.go +++ b/openmeter/ledger/transactions/resolve_test.go @@ -28,6 +28,10 @@ func (*spyCustomerTemplate) resolve(context.Context, customer.CustomerID, Resolv return nil, nil } +func (*spyCustomerTemplate) correct(CorrectionScope) ([]ledger.TransactionInput, error) { + return nil, nil +} + var _ CustomerTransactionTemplate = (*spyCustomerTemplate)(nil) func TestResolveTransactions_callsResolverValidate(t *testing.T) { @@ -46,5 +50,43 @@ func TestResolveTransactions_callsResolverValidate(t *testing.T) { spy, ) require.NoError(t, err) - require.Equal(t, 1, spy.validateCalls, "Resolver.Validate must be invoked for each template") + require.Equal(t, 1, spy.validateCalls, "TransactionTemplate.Validate must be invoked for each template") +} + +type annotatedCustomerTemplate struct{} + +func (annotatedCustomerTemplate) Validate() error { + return nil +} + +func (annotatedCustomerTemplate) typeGuard() guard { + return true +} + +func (annotatedCustomerTemplate) resolve(_ context.Context, _ customer.CustomerID, _ ResolverDependencies) (ledger.TransactionInput, error) { + return &TransactionInput{}, nil +} + +func (annotatedCustomerTemplate) correct(CorrectionScope) ([]ledger.TransactionInput, error) { + return nil, nil +} + +func TestResolveTransactions_addsTemplateAnnotations(t *testing.T) { + t.Parallel() + + inputs, err := ResolveTransactions( + t.Context(), + ResolverDependencies{}, + ResolutionScope{ + CustomerID: customer.CustomerID{ + Namespace: "ns", + ID: "cust", + }, + }, + annotatedCustomerTemplate{}, + ) + require.NoError(t, err) + require.Len(t, inputs, 1) + require.Equal(t, "annotatedCustomerTemplate", inputs[0].Annotations()[ledger.AnnotationTransactionTemplateName]) + require.Equal(t, string(ledger.TransactionDirectionForward), inputs[0].Annotations()[ledger.AnnotationTransactionDirection]) } diff --git a/openmeter/ledger/transactions/template.go b/openmeter/ledger/transactions/template.go index b140b0a0c8..1f81e01fe7 100644 --- a/openmeter/ledger/transactions/template.go +++ b/openmeter/ledger/transactions/template.go @@ -7,18 +7,28 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger" ) +type ( + guard bool // private type guard + TransactionTemplate interface { + typeGuard() guard + Validate() error + } +) + // CustomerTransactionTemplate is a template for customer scoped transactions type CustomerTransactionTemplate interface { - Resolver + TransactionTemplate // Resolve resolves the template's intent for a concrete customer resolve(ctx context.Context, customerID customer.CustomerID, resolvers ResolverDependencies) (ledger.TransactionInput, error) + correct(scope CorrectionScope) ([]ledger.TransactionInput, error) } // OrgTransactionTemplate is a template for organization scoped transactions type OrgTransactionTemplate interface { - Resolver + TransactionTemplate // Resolve resolves the template's intent for a given organization resolve(ctx context.Context, namespace string, resolvers ResolverDependencies) (ledger.TransactionInput, error) + correct(scope CorrectionScope) ([]ledger.TransactionInput, error) } diff --git a/openmeter/ledger/transactions/testenv_test.go b/openmeter/ledger/transactions/testenv_test.go index 8af0ee6c20..53d710778a 100644 --- a/openmeter/ledger/transactions/testenv_test.go +++ b/openmeter/ledger/transactions/testenv_test.go @@ -27,7 +27,7 @@ func (e *transactionsTestEnv) resolverDeps() ResolverDependencies { } } -func (e *transactionsTestEnv) resolve(t *testing.T, templates ...Resolver) []ledger.TransactionInput { +func (e *transactionsTestEnv) resolve(t *testing.T, templates ...TransactionTemplate) []ledger.TransactionInput { t.Helper() inputs, err := ResolveTransactions( @@ -51,7 +51,7 @@ func (e *transactionsTestEnv) commit(t *testing.T, inputs ...ledger.TransactionI require.NoError(t, err) } -func (e *transactionsTestEnv) resolveAndCommit(t *testing.T, templates ...Resolver) []ledger.TransactionInput { +func (e *transactionsTestEnv) resolveAndCommit(t *testing.T, templates ...TransactionTemplate) []ledger.TransactionInput { t.Helper() inputs := e.resolve(t, templates...) diff --git a/openmeter/ledger/transactions/testutils/anytransaction.go b/openmeter/ledger/transactions/testutils/anytransaction.go index 3f4332ea09..9c66da136c 100644 --- a/openmeter/ledger/transactions/testutils/anytransaction.go +++ b/openmeter/ledger/transactions/testutils/anytransaction.go @@ -28,6 +28,7 @@ func (a *AnyEntryInput) Amount() alpacadecimal.Decimal { type AnyTransactionInput struct { BookedAtValue time.Time EntryInputsValues []*AnyEntryInput + AnnotationsValue models.Annotations } var _ ledger.TransactionInput = (*AnyTransactionInput)(nil) @@ -42,6 +43,10 @@ func (a *AnyTransactionInput) EntryInputs() []ledger.EntryInput { }) } +func (a *AnyTransactionInput) Annotations() models.Annotations { + return a.AnnotationsValue +} + func (a *AnyTransactionInput) AsGroupInput(namespace string, annotations models.Annotations) ledger.TransactionGroupInput { return &AnyTransactionGroupInput{NamespaceValue: namespace, TransactionsValues: []*AnyTransactionInput{a}, AnnotationsValue: annotations} } diff --git a/test/credits/sanity_lifecycle_test.go b/test/credits/sanity_lifecycle_test.go new file mode 100644 index 0000000000..35662cdb47 --- /dev/null +++ b/test/credits/sanity_lifecycle_test.go @@ -0,0 +1,606 @@ +package credits + +import ( + "context" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + "github.com/samber/mo" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/charges/models/payment" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + streamingtestutils "github.com/openmeterio/openmeter/openmeter/streaming/testutils" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/datetime" + "github.com/openmeterio/openmeter/pkg/timeutil" + billingtest "github.com/openmeterio/openmeter/test/billing" +) + +type usageBasedPartialBackfillLifecycleState struct { + customerID customer.CustomerID + usageChargeID meta.ChargeID + creditPurchaseChargeID meta.ChargeID + purchaseAmount alpacadecimal.Decimal + costBasis alpacadecimal.Decimal +} + +func (s *CreditsTestSuite) TestUsageBasedCreditOnlyLifecyclePartialBackfillCorrectionThenDeleteSanity() { + ctx := context.Background() + state := s.setupUsageBasedCreditOnlyLifecyclePartialBackfillCorrection(ctx, "charges-sanity-usagebased-credit-only-lifecycle-partial-backfill-correction-delete") + + // When the now-corrected charge is deleted with refund-as-credits, the delete path has to use + // the already-written-back lineage state rather than the original pre-correction split. + err := s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: state.customerID, + PatchesByChargeID: map[string]charges.Patch{ + state.usageChargeID.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) + + // Then the corrected usage is fully unwound. The only remaining open receivable is the still-unsettled + // purchase-side obligation in the purchased cost-basis bucket. + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(state.purchaseAmount.Neg(), s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some(&state.costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some(&state.costBasis))) + s.Equal(state.purchaseAmount, s.mustCustomerFBOBalance(state.customerID, USD, mo.Some(&state.costBasis))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil))) + + // When we close the later credit purchase payment lifecycle too. + s.mustSettleExternalCreditPurchase(ctx, state.creditPurchaseChargeID) + + // Then the purchased-cost-basis receivable is fully cleaned up, while the refunded purchased + // credits stay available in FBO. The remaining nil-cost-basis receivable is also netted out here. + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some(&state.costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some(&state.costBasis))) + s.Equal(state.purchaseAmount, s.mustCustomerFBOBalance(state.customerID, USD, mo.Some(&state.costBasis))) +} + +func (s *CreditsTestSuite) TestUsageBasedCreditOnlyLifecyclePartialBackfillCorrectionSettleBeforeDeleteSanity() { + ctx := context.Background() + state := s.setupUsageBasedCreditOnlyLifecyclePartialBackfillCorrection(ctx, "charges-sanity-usagebased-credit-only-lifecycle-partial-backfill-correction-settle-before-delete") + + // When we close the later credit purchase payment lifecycle before refunding the original charge. + s.mustSettleExternalCreditPurchase(ctx, state.creditPurchaseChargeID) + + // Then the purchased receivable is already cleaned up, but the corrected purchased-credit-backed + // usage is still split between accrued and available FBO. + s.Equal(alpacadecimal.NewFromInt(-5), s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some(&state.costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(6), s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some(&state.costBasis))) + s.Equal(alpacadecimal.NewFromInt(9), s.mustCustomerFBOBalance(state.customerID, USD, mo.Some(&state.costBasis))) + + // When the original charge is deleted with refund-as-credits afterwards. + err := s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: state.customerID, + PatchesByChargeID: map[string]charges.Patch{ + state.usageChargeID.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) + + // Then the end state is fully cleaned up: the purchase is settled, the corrected usage is refunded, + // and no receivable remains open on either route. + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(state.customerID, USD, mo.Some(&state.costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(state.customerID, USD, mo.Some(&state.costBasis))) + s.Equal(state.purchaseAmount, s.mustCustomerFBOBalance(state.customerID, USD, mo.Some(&state.costBasis))) +} + +func (s *CreditsTestSuite) TestUsageBasedCreditOnlyLifecycleTwoChargesTwoPurchasesSanity() { + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-sanity-usagebased-credit-only-lifecycle-two-charges-two-purchases") + + cust := s.createLedgerBackedCustomer(ns, "test-subject") + sandboxApp := s.InstallSandboxApp(s.T(), ns) + _ = s.ProvisionBillingProfile(ctx, ns, sandboxApp.GetID(), + billingtest.WithProgressiveBilling(), + billingtest.WithCollectionInterval(datetime.MustParseDuration(s.T(), "P2D")), + billingtest.WithManualApproval(), + ) + + apiRequestsTotal := s.SetupApiRequestsTotalFeature(ctx, ns) + meterSlug := apiRequestsTotal.Feature.Key + + servicePeriodA := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + servicePeriodB := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-03-01T00:00:00Z", time.UTC).AsTime(), + } + chargeAFinalAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-03T00:01:00Z", time.UTC).AsTime() + chargeBCreateAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-15T00:00:00Z", time.UTC).AsTime() + chargeBStartFinalizationAt := datetime.MustParseTimeInLocation(s.T(), "2026-03-01T12:00:00Z", time.UTC).AsTime() + chargeBFinalizeAt := datetime.MustParseTimeInLocation(s.T(), "2026-03-03T00:01:00Z", time.UTC).AsTime() + purchase1Amount := alpacadecimal.NewFromInt(25) + purchase2Amount := alpacadecimal.NewFromInt(10) + costBasis1 := alpacadecimal.NewFromFloat(0.5) + costBasis2 := alpacadecimal.NewFromFloat(0.8) + + clock.FreezeTime(chargeAFinalAt) + defer clock.UnFreeze() + + // Given Charge A belongs to an older service period and is created after its collection window. + // When it is created with 20 units at $1/unit, it finalizes immediately as advance-backed usage. + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 20, + datetime.MustParseTimeInLocation(s.T(), "2026-01-15T00:00:00Z", time.UTC).AsTime(), + ) + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriodA, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromInt(1)}), + name: "usage-based-charge-a", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: "usage-based-charge-a", + featureKey: meterSlug, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + chargeA, err := res[0].AsUsageBasedCharge() + s.NoError(err) + s.Equal(meta.ChargeStatusFinal, meta.ChargeStatus(chargeA.Status)) + s.Equal(alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + + // Given Charge B belongs to the next service period. + // When it is created while that service period is already active, it starts in Active with no allocation yet. + clock.FreezeTime(chargeBCreateAt) + priceB := productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + Mode: productcatalog.VolumeTieredPrice, + Tiers: []productcatalog.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromInt(10)), + UnitPrice: &productcatalog.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromInt(2), + }, + }, + { + UpToAmount: nil, + UnitPrice: &productcatalog.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromInt(1), + }, + }, + }, + }) + res, err = s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriodB, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: priceB, + name: "usage-based-charge-b", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: "usage-based-charge-b", + featureKey: meterSlug, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + chargeB, err := res[0].AsUsageBasedCharge() + s.NoError(err) + s.Equal(meta.ChargeStatusActive, meta.ChargeStatus(chargeB.Status)) + s.Equal(alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + + // Given Charge B records 10 units during its own service period. + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 10, + datetime.MustParseTimeInLocation(s.T(), "2026-02-20T00:00:00Z", time.UTC).AsTime(), + ) + + // When Charge B starts finalization, it allocates 20 more advance-backed credits. + clock.FreezeTime(chargeBStartFinalizationAt) + advancedChargeB := s.mustAdvanceUsageBasedChargeByID(ctx, cust.GetID(), chargeB.GetChargeID()) + s.Require().NotNil(advancedChargeB) + s.Equal(usagebased.StatusActiveFinalRealizationWaitingForCollection, advancedChargeB.Status) + s.Equal(alpacadecimal.NewFromInt(-40), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(40), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + + // Given the first later credit purchase arrives while both charges still contribute uncovered advance. + // When the customer buys 25 credits at cost basis 0.5, it backfills the older uncovered usage first. + res, err = s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createCreditPurchaseIntent(createCreditPurchaseIntentInput{ + customer: cust.GetID(), + currency: USD, + amount: purchase1Amount, + servicePeriod: timeutil.ClosedPeriod{ + From: chargeBStartFinalizationAt, + To: chargeBStartFinalizationAt, + }, + settlement: creditpurchase.NewSettlement(creditpurchase.ExternalSettlement{ + GenericSettlement: creditpurchase.GenericSettlement{ + Currency: USD, + CostBasis: costBasis1, + }, + InitialStatus: creditpurchase.CreatedInitialPaymentSettlementStatus, + }), + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + purchase1Charge, err := res[0].AsCreditPurchaseCharge() + s.NoError(err) + s.Equal(alpacadecimal.NewFromInt(-40), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(-15), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(15), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchase1Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis1), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(purchase1Amount, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + + // Given one more unit becomes visible for Charge B before the final cutoff. + // This reduces Charge B's priced amount from 20 down to 11, so part of Purchase 1 is released again. + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 1, + datetime.MustParseTimeInLocation(s.T(), "2026-02-21T00:00:00Z", time.UTC).AsTime(), + streamingtestutils.WithStoredAt(datetime.MustParseTimeInLocation(s.T(), "2026-03-02T00:00:00Z", time.UTC).AsTime()), + ) + + // When Charge B finalizes, the lifecycle-driven correction should free the 5 cost-basis-backed + // part first and only then reduce the still-uncovered remainder. + // That 5 is the portion of Purchase 1 that had already been attributed to Charge B after + // fully backfilling Charge A's older 20 first. + // !!! Released purchased credit goes back to FBO here. It does not immediately snap onto + // Charge B's or any other charge's remaining uncovered advance. Only a later purchase/initiation + // pass will backfill uncovered advance again. + clock.FreezeTime(chargeBFinalizeAt) + advancedChargeB = s.mustAdvanceUsageBasedChargeByID(ctx, cust.GetID(), chargeB.GetChargeID()) + s.Require().NotNil(advancedChargeB) + s.Equal(meta.ChargeStatusFinal, meta.ChargeStatus(advancedChargeB.Status)) + s.Equal(alpacadecimal.NewFromInt(-36), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + // After the correction, Charge A still accounts for the full 20 costBasis1-backed usage, + // while Charge B drops back to 11 uncovered usage and releases those 5 purchased credits to FBO. + s.Equal(alpacadecimal.NewFromInt(-11), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(11), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchase1Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis1), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + + // Given a second later credit purchase now sees only Charge B's remaining uncovered amount. + // !!! The released 5 from Purchase 1 stayed as available purchased credit in FBO; it did not + // auto-cover this remaining uncovered advance on its own. + // When the customer buys another 10 credits at a different cost basis, it should backfill only Charge B. + clock.FreezeTime(chargeBFinalizeAt.Add(time.Minute)) + res, err = s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createCreditPurchaseIntent(createCreditPurchaseIntentInput{ + customer: cust.GetID(), + currency: USD, + amount: purchase2Amount, + servicePeriod: timeutil.ClosedPeriod{ + From: clock.Now(), + To: clock.Now(), + }, + settlement: creditpurchase.NewSettlement(creditpurchase.ExternalSettlement{ + GenericSettlement: creditpurchase.GenericSettlement{ + Currency: USD, + CostBasis: costBasis2, + }, + InitialStatus: creditpurchase.CreatedInitialPaymentSettlementStatus, + }), + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + purchase2Charge, err := res[0].AsCreditPurchaseCharge() + s.NoError(err) + s.Equal(alpacadecimal.NewFromInt(-36), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(-1), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(1), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchase1Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis1), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(purchase2Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis2), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(purchase2Amount, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis2))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis2))) + + // When Charge B is refunded, only its current backing should be released. + s.mustRefundCharge(ctx, cust.GetID(), chargeB.GetChargeID()) + s.Equal(alpacadecimal.NewFromInt(-35), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchase1Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis1), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(purchase2Amount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis2), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis2))) + s.Equal(purchase2Amount, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis2))) + + // When both later purchases complete their payment lifecycle too. + s.mustSettleExternalCreditPurchase(ctx, purchase1Charge.GetChargeID()) + s.mustSettleExternalCreditPurchase(ctx, purchase2Charge.GetChargeID()) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis1), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis1))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis2), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis2))) + s.Equal(purchase2Amount, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis2))) +} + +// Use this helper for the shared single-charge lifecycle setup that stops after +// the later correction has already been applied. +func (s *CreditsTestSuite) setupUsageBasedCreditOnlyLifecyclePartialBackfillCorrection(ctx context.Context, namespacePrefix string) usageBasedPartialBackfillLifecycleState { + ns := s.GetUniqueNamespace(namespacePrefix) + + cust := s.createLedgerBackedCustomer(ns, "test-subject") + sandboxApp := s.InstallSandboxApp(s.T(), ns) + _ = s.ProvisionBillingProfile(ctx, ns, sandboxApp.GetID(), + billingtest.WithProgressiveBilling(), + billingtest.WithCollectionInterval(datetime.MustParseDuration(s.T(), "P2D")), + billingtest.WithManualApproval(), + ) + + apiRequestsTotal := s.SetupApiRequestsTotalFeature(ctx, ns) + meterSlug := apiRequestsTotal.Feature.Key + + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + createAt := datetime.MustParseTimeInLocation(s.T(), "2025-12-01T00:00:00Z", time.UTC).AsTime() + startFinalizationAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-01T12:00:00Z", time.UTC).AsTime() + finalizeAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-03T00:01:00Z", time.UTC).AsTime() + purchaseAmount := alpacadecimal.NewFromInt(15) + costBasis := alpacadecimal.NewFromFloat(0.5) + + clock.FreezeTime(createAt) + defer clock.UnFreeze() + + price := productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + Mode: productcatalog.VolumeTieredPrice, + Tiers: []productcatalog.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromInt(10)), + UnitPrice: &productcatalog.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromInt(2), + }, + }, + { + UpToAmount: nil, + UnitPrice: &productcatalog.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromInt(1), + }, + }, + }, + }) + + // Given current wall clock is 2025-12-01T00:00:00Z, well before the service period. + // When creating a credit-only usage-based charge with a tiered price. + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: price, + name: "usage-based-credit-only-lifecycle-partial-backfill-correction-delete", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: namespacePrefix, + featureKey: meterSlug, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + usageCharge, err := res[0].AsUsageBasedCharge() + s.NoError(err) + + // Then the first advance at service period start only moves the charge into Active. + clock.FreezeTime(servicePeriod.From) + advancedCharge := s.mustAdvanceSingleUsageBasedCharge(ctx, cust.GetID()) + s.Require().NotNil(advancedCharge) + s.Equal(meta.ChargeStatusActive, meta.ChargeStatus(advancedCharge.Status)) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis))) + + // Given the customer records 10 units during the service period. + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 10, + datetime.MustParseTimeInLocation(s.T(), "2026-01-15T00:00:00Z", time.UTC).AsTime(), + ) + + // When we advance after the service period, the final realization starts and allocates the + // initial 20 credits, but the charge still waits for the collection window to close. + clock.FreezeTime(startFinalizationAt) + advancedCharge = s.mustAdvanceSingleUsageBasedCharge(ctx, cust.GetID()) + s.Require().NotNil(advancedCharge) + s.Equal(usagebased.StatusActiveFinalRealizationWaitingForCollection, advancedCharge.Status) + s.Equal(alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(20), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis))) + + // Given a later external credit purchase partially backfills that earlier advance-backed usage. + res, err = s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createCreditPurchaseIntent(createCreditPurchaseIntentInput{ + customer: cust.GetID(), + currency: USD, + amount: purchaseAmount, + servicePeriod: timeutil.ClosedPeriod{ + From: startFinalizationAt, + To: startFinalizationAt, + }, + settlement: creditpurchase.NewSettlement(creditpurchase.ExternalSettlement{ + GenericSettlement: creditpurchase.GenericSettlement{ + Currency: USD, + CostBasis: costBasis, + }, + InitialStatus: creditpurchase.CreatedInitialPaymentSettlementStatus, + }), + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + creditPurchaseCharge, err := res[0].AsCreditPurchaseCharge() + s.NoError(err) + s.Equal(alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(-5), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchaseAmount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(purchaseAmount, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis))) + s.Equal(alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis))) + + // Given one more unit becomes visible before the final stored_at cutoff. + // This shrinks the previously allocated amount from 20 down to 11 during finalization. + s.MockStreamingConnector.AddSimpleEvent( + meterSlug, + 1, + datetime.MustParseTimeInLocation(s.T(), "2026-01-20T00:00:00Z", time.UTC).AsTime(), + streamingtestutils.WithStoredAt(datetime.MustParseTimeInLocation(s.T(), "2026-02-02T00:00:00Z", time.UTC).AsTime()), + ) + + // When we advance after the collection window, the normal lifecycle issues the correction. + clock.FreezeTime(finalizeAt) + advancedCharge = s.mustAdvanceSingleUsageBasedCharge(ctx, cust.GetID()) + s.Require().NotNil(advancedCharge) + s.Equal(meta.ChargeStatusFinal, meta.ChargeStatus(advancedCharge.Status)) + s.Equal(alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(-5), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(5), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + s.Equal(purchaseAmount.Neg(), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen)) + s.Equal(alpacadecimal.NewFromInt(6), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis))) + s.Equal(alpacadecimal.NewFromInt(9), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis))) + + return usageBasedPartialBackfillLifecycleState{ + customerID: cust.GetID(), + usageChargeID: usageCharge.GetChargeID(), + creditPurchaseChargeID: creditPurchaseCharge.GetChargeID(), + purchaseAmount: purchaseAmount, + costBasis: costBasis, + } +} + +// Use this helper when the test wants to drive a purchase through the normal +// external-payment authorized -> settled lifecycle. +func (s *CreditsTestSuite) mustSettleExternalCreditPurchase(ctx context.Context, chargeID meta.ChargeID) { + s.T().Helper() + + updatedCharge, err := s.Charges.HandleCreditPurchaseExternalPaymentStateTransition(ctx, charges.HandleCreditPurchaseExternalPaymentStateTransitionInput{ + ChargeID: chargeID, + TargetPaymentState: payment.StatusAuthorized, + }) + s.NoError(err) + s.Equal(payment.StatusAuthorized, updatedCharge.State.ExternalPaymentSettlement.Status) + + updatedCharge, err = s.Charges.HandleCreditPurchaseExternalPaymentStateTransition(ctx, charges.HandleCreditPurchaseExternalPaymentStateTransitionInput{ + ChargeID: chargeID, + TargetPaymentState: payment.StatusSettled, + }) + s.NoError(err) + s.Equal(payment.StatusSettled, updatedCharge.State.ExternalPaymentSettlement.Status) +} + +// Use this helper when the test wants to delete a charge through the real +// refund-as-credits patch flow. +func (s *CreditsTestSuite) mustRefundCharge(ctx context.Context, customerID customer.CustomerID, chargeID meta.ChargeID) { + s.T().Helper() + + err := s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: customerID, + PatchesByChargeID: map[string]charges.Patch{ + chargeID.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) +} + +// Use this helper when one advance call may return multiple usage-based charges +// and the test cares about the transition for one specific charge. +func (s *CreditsTestSuite) mustAdvanceUsageBasedChargeByID(ctx context.Context, customerID customer.CustomerID, chargeID meta.ChargeID) *usagebased.Charge { + s.T().Helper() + + advancedCharges, err := s.Charges.AdvanceCharges(ctx, charges.AdvanceChargesInput{ + Customer: customerID, + }) + s.NoError(err) + + for _, charge := range advancedCharges { + if charge.Type() != meta.ChargeTypeUsageBased { + continue + } + + advancedCharge, err := charge.AsUsageBasedCharge() + s.NoError(err) + + if advancedCharge.GetChargeID() == chargeID { + return &advancedCharge + } + } + + return nil +} + +// Use this helper when the test expects exactly one usage-based charge to advance. +func (s *CreditsTestSuite) mustAdvanceSingleUsageBasedCharge(ctx context.Context, customerID customer.CustomerID) *usagebased.Charge { + s.T().Helper() + + advancedCharges, err := s.Charges.AdvanceCharges(ctx, charges.AdvanceChargesInput{ + Customer: customerID, + }) + s.NoError(err) + + if len(advancedCharges) == 0 { + return nil + } + + s.Len(advancedCharges, 1) + s.Equal(meta.ChargeTypeUsageBased, advancedCharges[0].Type()) + + advancedCharge, err := advancedCharges[0].AsUsageBasedCharge() + s.NoError(err) + + return &advancedCharge +} diff --git a/test/credits/sanity_test.go b/test/credits/sanity_test.go index f4452ca784..2a413092a0 100644 --- a/test/credits/sanity_test.go +++ b/test/credits/sanity_test.go @@ -26,13 +26,16 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger" ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" ledgerchargeadapter "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" ledgerresolvers "github.com/openmeterio/openmeter/openmeter/ledger/resolvers" ledgertestutils "github.com/openmeterio/openmeter/openmeter/ledger/testutils" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog" omtestutils "github.com/openmeterio/openmeter/openmeter/testutils" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/datetime" + "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/timeutil" billingtest "github.com/openmeterio/openmeter/test/billing" ) @@ -64,20 +67,300 @@ func (s *CreditsTestSuite) SetupSuite() { s.LedgerAccountService = deps.AccountService s.LedgerResolver = deps.ResolversService + collectorService := ledgercollector.NewService(ledgercollector.Config{ + Ledger: deps.HistoricalLedger, + Dependencies: transactions.ResolverDependencies{ + AccountService: deps.ResolversService, + SubAccountService: deps.AccountService, + }, + }) + stack, err := chargestestutils.NewServices(s.T(), chargestestutils.Config{ Client: s.DBClient, Logger: logger, BillingService: s.BillingService, FeatureService: s.FeatureService, StreamingConnector: s.MockStreamingConnector, - FlatFeeHandler: ledgerchargeadapter.NewFlatFeeHandler(deps.HistoricalLedger, deps.ResolversService, deps.AccountService), + FlatFeeHandler: ledgerchargeadapter.NewFlatFeeHandler(deps.HistoricalLedger, transactions.ResolverDependencies{AccountService: deps.ResolversService, SubAccountService: deps.AccountService}, collectorService), CreditPurchaseHandler: ledgerchargeadapter.NewCreditPurchaseHandler(deps.HistoricalLedger, deps.ResolversService, deps.AccountService), - UsageBasedHandler: usagebased.UnimplementedHandler{}, + UsageBasedHandler: ledgerchargeadapter.NewUsageBasedHandler(collectorService), }) s.NoError(err) s.Charges = stack.ChargesService } +func (s *CreditsTestSuite) TearDownTest() { + s.MockStreamingConnector.Reset() + clock.UnFreeze() + clock.ResetTime() +} + +func (s *CreditsTestSuite) TestFlatFeeCreditOnlyDeleteCorrectionSanity() { + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-sanity-flatfee-credit-only-delete") + + customInvoicing := s.SetupCustomInvoicing(ns) + cust := s.createLedgerBackedCustomer(ns, "test-subject") + + _ = s.ProvisionBillingProfile(ctx, ns, customInvoicing.App.GetID(), + billingtest.WithProgressiveBilling(), + billingtest.WithCollectionInterval(datetime.MustParseDuration(s.T(), "PT1H")), + billingtest.WithManualApproval(), + ) + + createAt := datetime.MustParseTimeInLocation(s.T(), "2025-12-01T00:00:00Z", time.UTC).AsTime() + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + + clock.FreezeTime(createAt) + defer clock.UnFreeze() + + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromInt(30), + PaymentTerm: productcatalog.InAdvancePaymentTerm, + }), + name: "flat-fee-credit-only-delete", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: "flat-fee-credit-only-delete", + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + flatFeeChargeID, err := res[0].GetChargeID() + s.NoError(err) + + clock.FreezeTime(servicePeriod.From) + + advancedCharges, err := s.Charges.AdvanceCharges(ctx, charges.AdvanceChargesInput{ + Customer: cust.GetID(), + }) + s.NoError(err) + s.Len(advancedCharges, 1) + + advancedCharge, err := advancedCharges[0].AsFlatFeeCharge() + s.NoError(err) + s.Equal(meta.ChargeStatusFinal, advancedCharge.Status) + s.Len(advancedCharge.State.CreditRealizations, 1) + + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.NewFromInt(-30))) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.NewFromInt(30))) + + err = s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: cust.GetID(), + PatchesByChargeID: map[string]charges.Patch{ + flatFeeChargeID.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) + + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) +} + +func (s *CreditsTestSuite) TestUsageBasedCreditOnlyDeleteCorrectionSanity() { + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-sanity-usagebased-credit-only-delete") + + cust := s.createLedgerBackedCustomer(ns, "test-subject") + sandboxApp := s.InstallSandboxApp(s.T(), ns) + _ = s.ProvisionBillingProfile(ctx, ns, sandboxApp.GetID()) + + apiRequestsTotal := s.SetupApiRequestsTotalFeature(ctx, ns) + + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + createAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-03T00:00:00Z", time.UTC).AsTime() + + clock.FreezeTime(createAt) + defer clock.UnFreeze() + + s.MockStreamingConnector.AddSimpleEvent( + apiRequestsTotal.Feature.Key, + 8, + datetime.MustParseTimeInLocation(s.T(), "2026-01-15T00:00:00Z", time.UTC).AsTime(), + ) + + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Amount: alpacadecimal.NewFromInt(1), + }), + name: "usage-based-credit-only-delete", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: "usage-based-credit-only-delete", + featureKey: apiRequestsTotal.Feature.Key, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + usageBasedCharge, err := res[0].AsUsageBasedCharge() + s.NoError(err) + s.Equal(meta.ChargeStatusFinal, meta.ChargeStatus(usageBasedCharge.Status)) + s.Len(usageBasedCharge.Realizations, 1) + s.Len(usageBasedCharge.Realizations[0].CreditsAllocated, 1) + s.True(usageBasedCharge.Realizations[0].CreditsAllocated[0].Amount.Equal(alpacadecimal.NewFromInt(8))) + + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.NewFromInt(-8))) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.NewFromInt(8))) + + err = s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: cust.GetID(), + PatchesByChargeID: map[string]charges.Patch{ + usageBasedCharge.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) +} + +func (s *CreditsTestSuite) TestUsageBasedCreditOnlyDeleteCorrectionWithPartialBackfillSanity() { + ctx := context.Background() + ns := s.GetUniqueNamespace("charges-sanity-usagebased-credit-only-delete-partial-backfill") + + cust := s.createLedgerBackedCustomer(ns, "test-subject") + sandboxApp := s.InstallSandboxApp(s.T(), ns) + _ = s.ProvisionBillingProfile(ctx, ns, sandboxApp.GetID()) + + apiRequestsTotal := s.SetupApiRequestsTotalFeature(ctx, ns) + + servicePeriod := timeutil.ClosedPeriod{ + From: datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime(), + To: datetime.MustParseTimeInLocation(s.T(), "2026-02-01T00:00:00Z", time.UTC).AsTime(), + } + createAt := datetime.MustParseTimeInLocation(s.T(), "2026-02-03T00:00:00Z", time.UTC).AsTime() + + clock.FreezeTime(createAt) + defer clock.UnFreeze() + + // Given a usage-based credit-only charge that is created after the service period, so it + // finalizes immediately with 50 units of unattributed advance-backed usage. + s.MockStreamingConnector.AddSimpleEvent( + apiRequestsTotal.Feature.Key, + 50, + datetime.MustParseTimeInLocation(s.T(), "2026-01-15T00:00:00Z", time.UTC).AsTime(), + ) + + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.createMockChargeIntent(createMockChargeIntentInput{ + customer: cust.GetID(), + currency: USD, + servicePeriod: servicePeriod, + settlementMode: productcatalog.CreditOnlySettlementMode, + price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Amount: alpacadecimal.NewFromInt(1), + }), + name: "usage-based-credit-only-delete-partial-backfill", + managedBy: billing.SubscriptionManagedLine, + uniqueReferenceID: "usage-based-credit-only-delete-partial-backfill", + featureKey: apiRequestsTotal.Feature.Key, + }), + }, + }) + s.NoError(err) + s.Len(res, 1) + + usageBasedCharge, err := res[0].AsUsageBasedCharge() + s.NoError(err) + s.Equal(meta.ChargeStatusFinal, meta.ChargeStatus(usageBasedCharge.Status)) + s.Len(usageBasedCharge.Realizations, 1) + s.Len(usageBasedCharge.Realizations[0].CreditsAllocated, 1) + allocatedAmount := usageBasedCharge.Realizations[0].CreditsAllocated[0].Amount + purchaseAmount := alpacadecimal.NewFromInt(20) + remainingUncovered := allocatedAmount.Sub(purchaseAmount) + + // Then the full amount sits on the nil-cost-basis receivable/accrued path. + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(allocatedAmount.Neg())) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(allocatedAmount)) + + creditPurchaseIntent := s.createCreditPurchaseIntent(createCreditPurchaseIntentInput{ + customer: cust.GetID(), + currency: USD, + amount: purchaseAmount, + servicePeriod: timeutil.ClosedPeriod{ + From: createAt, + To: createAt, + }, + settlement: creditpurchase.NewSettlement(creditpurchase.ExternalSettlement{ + GenericSettlement: creditpurchase.GenericSettlement{ + Currency: USD, + CostBasis: alpacadecimal.NewFromFloat(0.5), + }, + InitialStatus: creditpurchase.CreatedInitialPaymentSettlementStatus, + }), + }) + + // When a later external credit purchase backfills part of that earlier advance-backed usage. + creditPurchaseRes, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + creditPurchaseIntent, + }, + }) + s.NoError(err) + s.Len(creditPurchaseRes, 1) + creditPurchaseCharge, err := creditPurchaseRes[0].AsCreditPurchaseCharge() + s.NoError(err) + + costBasis := alpacadecimal.NewFromFloat(0.5) + backingGroup, err := s.Ledger.GetTransactionGroup(ctx, models.NamespacedID{ + Namespace: ns, + ID: creditPurchaseCharge.State.CreditGrantRealization.TransactionGroupID, + }) + s.NoError(err) + s.Len(backingGroup.Transactions(), 2) + + // Then only the purchased portion moves onto the purchased-credit cost-basis route. + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(allocatedAmount.Neg())) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(remainingUncovered)) + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).Equal(purchaseAmount.Neg())) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis)).Equal(purchaseAmount)) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis)).Equal(alpacadecimal.Zero)) + + // When the original charge is deleted with refund-as-credits. + err = s.Charges.ApplyPatches(ctx, charges.ApplyPatchesInput{ + CustomerID: cust.GetID(), + PatchesByChargeID: map[string]charges.Patch{ + usageBasedCharge.ID: meta.NewPatchDelete(meta.RefundAsCreditsDeletePolicy), + }, + }) + s.NoError(err) + + // Then the purchased part is returned as available credit and the original accrued usage is cleared. + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen).Equal(purchaseAmount.Neg())) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&costBasis)).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis)).Equal(purchaseAmount)) + s.True(s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).Equal(purchaseAmount.Neg())) +} + func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { ctx := context.Background() ns := s.GetUniqueNamespace("charges-sanity-test") @@ -141,8 +424,8 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { // Validate balances zeroCostBasis := alpacadecimal.Zero purchasedCostBasis := alpacadecimal.NewFromFloat(0.5) - s.Equal(float64(30), s.mustCustomerFBOBalance(cust.GetID(), USD, &zeroCostBasis).InexactFloat64()) - s.Equal(float64(0), s.mustCustomerFBOBalance(cust.GetID(), USD, &purchasedCostBasis).InexactFloat64()) + s.Equal(float64(30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&zeroCostBasis)).InexactFloat64()) + s.Equal(float64(0), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&purchasedCostBasis)).InexactFloat64()) }) var externalCreditPurchaseChargeID meta.ChargeID @@ -186,8 +469,8 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { // Validate balances costBasis := alpacadecimal.NewFromFloat(0.5) - s.Equal(float64(50), s.mustCustomerFBOBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) - s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(50), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis)).InexactFloat64()) + s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) externalCreditPurchaseChargeID = cpCharge.GetChargeID() }) @@ -204,7 +487,7 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { costBasis := alpacadecimal.NewFromFloat(0.5) s.Equal(payment.StatusAuthorized, updatedCharge.State.ExternalPaymentSettlement.Status) - s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) }) s.Run("the customer settles the credit purchase payment", func() { @@ -219,7 +502,7 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { costBasis := alpacadecimal.NewFromFloat(0.5) s.Equal(payment.StatusSettled, updatedCharge.State.ExternalPaymentSettlement.Status) - s.Equal(float64(0), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(0), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) }) // TOTAL credits balance: 30 + 50 = 80 USD @@ -240,15 +523,15 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { earnings alpacadecimal.Decimal } flatFeeStart := flatFeeLedgerSnapshot{ - promoFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis), - externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis), - promoReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis), - externalReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis), - totalOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, nil), - accrued: s.mustCustomerAccruedBalance(cust.GetID(), USD), - authorizedReceivable: s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil), - totalWash: s.mustWashBalance(ns, USD, nil), - externalWash: s.mustWashBalance(ns, USD, &externalCostBasis), + promoFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis)), + externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)), + promoReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen), + externalReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen), + totalOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen), + accrued: s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]()), + authorizedReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized), + totalWash: s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]()), + externalWash: s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis)), earnings: s.mustEarningsBalance(ns, USD), } assertDelta := func(label string, start, delta, actual alpacadecimal.Decimal) { @@ -327,15 +610,15 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { customerCreditRealization := updatedFlatFeeCharge.State.CreditRealizations[1] s.Equal(float64(50), customerCreditRealization.Amount.InexactFloat64()) - assertDelta("promo FBO after invoice assignment", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external FBO after invoice assignment", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("promo receivable after invoice assignment", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external receivable after invoice assignment", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("total open receivable after invoice assignment", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("accrued after invoice assignment", flatFeeStart.accrued, alpacadecimal.NewFromInt(80), s.mustCustomerAccruedBalance(cust.GetID(), USD)) - assertDelta("authorized receivable after invoice assignment", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("total wash after invoice assignment", flatFeeStart.totalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, nil)) - assertDelta("external wash after invoice assignment", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, &externalCostBasis)) + assertDelta("promo FBO after invoice assignment", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis))) + assertDelta("external FBO after invoice assignment", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis))) + assertDelta("promo receivable after invoice assignment", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("external receivable after invoice assignment", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("total open receivable after invoice assignment", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("accrued after invoice assignment", flatFeeStart.accrued, alpacadecimal.NewFromInt(80), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("authorized receivable after invoice assignment", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized)) + assertDelta("total wash after invoice assignment", flatFeeStart.totalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("external wash after invoice assignment", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis))) assertDelta("earnings after invoice assignment", flatFeeStart.earnings, alpacadecimal.Zero, s.mustEarningsBalance(ns, USD)) stdInvoiceID = invoice.GetInvoiceID() @@ -366,16 +649,16 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { s.Equal(float64(20), accruedUsage.Totals.Total.InexactFloat64(), "totals should be the same as the input") s.Equal(float64(80), accruedUsage.Totals.CreditsTotal.InexactFloat64(), "totals should be the same as the input") - assertDelta("promo FBO after payment authorization", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external FBO after payment authorization", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("promo receivable after payment authorization", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external receivable after payment authorization", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("total open receivable after payment authorization", flatFeeStart.totalOpenReceivable, alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("authorized receivable after payment authorization", flatFeeStart.authorizedReceivable, alpacadecimal.NewFromInt(20), s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("accrued after payment authorization", flatFeeStart.accrued, alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD)) - assertDelta("total wash after payment authorization", flatFeeStart.totalWash, alpacadecimal.NewFromInt(-20), s.mustWashBalance(ns, USD, nil)) - assertDelta("external wash after payment authorization", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, &externalCostBasis)) - assertDelta("earnings after payment authorization", flatFeeStart.earnings, alpacadecimal.NewFromInt(100), s.mustEarningsBalance(ns, USD)) + assertDelta("promo FBO after payment authorization", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis))) + assertDelta("external FBO after payment authorization", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis))) + assertDelta("promo receivable after payment authorization", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("external receivable after payment authorization", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("total open receivable after payment authorization", flatFeeStart.totalOpenReceivable, alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("authorized receivable after payment authorization", flatFeeStart.authorizedReceivable, alpacadecimal.NewFromInt(20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized)) + assertDelta("accrued after payment authorization", flatFeeStart.accrued, alpacadecimal.NewFromInt(100), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("total wash after payment authorization", flatFeeStart.totalWash, alpacadecimal.NewFromInt(-20), s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("external wash after payment authorization", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis))) + assertDelta("earnings after payment authorization", flatFeeStart.earnings, alpacadecimal.Zero, s.mustEarningsBalance(ns, USD)) }) s.Run("payment is settled", func() { @@ -394,14 +677,14 @@ func (s *CreditsTestSuite) TestFlatFeeCreditThenInvoiceSanity() { s.NoError(err) s.Equal(meta.ChargeStatusFinal, updatedFlatFeeCharge.Status) - assertDelta("promo receivable after payment settlement", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external receivable after payment settlement", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("total open receivable after payment settlement", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("authorized receivable after payment settlement", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("accrued after payment settlement", flatFeeStart.accrued, alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD)) - assertDelta("total wash after payment settlement", flatFeeStart.totalWash, alpacadecimal.NewFromInt(-20), s.mustWashBalance(ns, USD, nil)) - assertDelta("external wash after payment settlement", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, &externalCostBasis)) - assertDelta("earnings after payment settlement", flatFeeStart.earnings, alpacadecimal.NewFromInt(100), s.mustEarningsBalance(ns, USD)) + assertDelta("promo receivable after payment settlement", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("external receivable after payment settlement", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("total open receivable after payment settlement", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("authorized receivable after payment settlement", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized)) + assertDelta("accrued after payment settlement", flatFeeStart.accrued, alpacadecimal.NewFromInt(100), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("total wash after payment settlement", flatFeeStart.totalWash, alpacadecimal.NewFromInt(-20), s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("external wash after payment settlement", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis))) + assertDelta("earnings after payment settlement", flatFeeStart.earnings, alpacadecimal.Zero, s.mustEarningsBalance(ns, USD)) }) } @@ -442,8 +725,8 @@ func (s *CreditsTestSuite) TestCreditPurchasePersistsPriority() { s.Equal(&priority, fetchedCharge.Intent.Priority) zeroCostBasis := alpacadecimal.Zero - s.True(s.mustCustomerFBOBalanceWithPriority(cust.GetID(), USD, &zeroCostBasis, priority).Equal(alpacadecimal.NewFromInt(25))) - s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, &zeroCostBasis).Equal(alpacadecimal.Zero)) + s.True(s.mustCustomerFBOBalanceWithPriority(cust.GetID(), USD, mo.Some(&zeroCostBasis), priority).Equal(alpacadecimal.NewFromInt(25))) + s.True(s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&zeroCostBasis)).Equal(alpacadecimal.Zero)) } func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { @@ -502,8 +785,8 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { zeroCostBasis := alpacadecimal.Zero purchasedCostBasis := alpacadecimal.NewFromFloat(0.5) - s.Equal(float64(30), s.mustCustomerFBOBalance(cust.GetID(), USD, &zeroCostBasis).InexactFloat64()) - s.Equal(float64(0), s.mustCustomerFBOBalance(cust.GetID(), USD, &purchasedCostBasis).InexactFloat64()) + s.Equal(float64(30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&zeroCostBasis)).InexactFloat64()) + s.Equal(float64(0), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&purchasedCostBasis)).InexactFloat64()) }) var externalCreditPurchaseChargeID meta.ChargeID @@ -540,8 +823,8 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { s.NotEmpty(cpCharge.State.CreditGrantRealization.TransactionGroupID) costBasis := alpacadecimal.NewFromFloat(0.5) - s.Equal(float64(50), s.mustCustomerFBOBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) - s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(50), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&costBasis)).InexactFloat64()) + s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) externalCreditPurchaseChargeID = cpCharge.GetChargeID() }) @@ -555,7 +838,7 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { costBasis := alpacadecimal.NewFromFloat(0.5) s.Equal(payment.StatusAuthorized, updatedCharge.State.ExternalPaymentSettlement.Status) - s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(-50), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) }) s.Run("the customer settles the credit purchase payment", func() { @@ -567,7 +850,7 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { costBasis := alpacadecimal.NewFromFloat(0.5) s.Equal(payment.StatusSettled, updatedCharge.State.ExternalPaymentSettlement.Status) - s.Equal(float64(0), s.mustCustomerReceivableBalance(cust.GetID(), USD, &costBasis).InexactFloat64()) + s.Equal(float64(0), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen).InexactFloat64()) }) var flatFeeChargeID meta.ChargeID @@ -587,16 +870,16 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { earnings alpacadecimal.Decimal } flatFeeStart := flatFeeLedgerSnapshot{ - promoFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis), - externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis), - unknownFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, nil), - promoReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis), - externalReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis), - totalOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, nil), - accrued: s.mustCustomerAccruedBalance(cust.GetID(), USD), - authorizedReceivable: s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil), - totalWash: s.mustWashBalance(ns, USD, nil), - externalWash: s.mustWashBalance(ns, USD, &externalCostBasis), + promoFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis)), + externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)), + unknownFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)), + promoReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen), + externalReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen), + totalOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen), + accrued: s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]()), + authorizedReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized), + totalWash: s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]()), + externalWash: s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis)), earnings: s.mustEarningsBalance(ns, USD), } assertDelta := func(label string, start, delta, actual alpacadecimal.Decimal) { @@ -645,12 +928,12 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { // Credit-only flat fees bypass invoice creation and are only allocated once the charge advances at InvoiceAt, // so creating the charge early should leave every ledger bucket untouched. - assertDelta("promo FBO after credit-only create", flatFeeStart.promoFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external FBO after credit-only create", flatFeeStart.externalFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("unknown FBO after credit-only create", flatFeeStart.unknownFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, nil)) - assertDelta("authorized receivable after credit-only create", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("total open receivable after credit-only create", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("accrued after credit-only create", flatFeeStart.accrued, alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD)) + assertDelta("promo FBO after credit-only create", flatFeeStart.promoFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis))) + assertDelta("external FBO after credit-only create", flatFeeStart.externalFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis))) + assertDelta("unknown FBO after credit-only create", flatFeeStart.unknownFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + assertDelta("authorized receivable after credit-only create", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized)) + assertDelta("total open receivable after credit-only create", flatFeeStart.totalOpenReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("accrued after credit-only create", flatFeeStart.accrued, alpacadecimal.Zero, s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]())) assertDelta("earnings after credit-only create", flatFeeStart.earnings, alpacadecimal.Zero, s.mustEarningsBalance(ns, USD)) }) @@ -689,24 +972,24 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { // - the uncovered remainder becomes open receivable immediately // - authorized receivable stays empty because no payment authorization happens // - wash and earnings stay unchanged because this flow never enters the invoice payment lifecycle - assertDelta("promo FBO after credit-only advance", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external FBO after credit-only advance", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("unknown FBO after credit-only advance", flatFeeStart.unknownFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, nil)) - assertDelta("promo receivable after credit-only advance", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &promoCostBasis)) - assertDelta("external receivable after credit-only advance", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, &externalCostBasis)) - assertDelta("total open receivable after credit-only advance", flatFeeStart.totalOpenReceivable, alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, nil)) - assertDelta("authorized receivable after credit-only advance", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerAuthorizedReceivableBalance(cust.GetID(), USD, nil)) + assertDelta("promo FBO after credit-only advance", flatFeeStart.promoFBO, alpacadecimal.NewFromInt(-30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&promoCostBasis))) + assertDelta("external FBO after credit-only advance", flatFeeStart.externalFBO, alpacadecimal.NewFromInt(-50), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis))) + assertDelta("unknown FBO after credit-only advance", flatFeeStart.unknownFBO, alpacadecimal.Zero, s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil))) + assertDelta("promo receivable after credit-only advance", flatFeeStart.promoReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&promoCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("external receivable after credit-only advance", flatFeeStart.externalReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("total open receivable after credit-only advance", flatFeeStart.totalOpenReceivable, alpacadecimal.NewFromInt(-20), s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusOpen)) + assertDelta("authorized receivable after credit-only advance", flatFeeStart.authorizedReceivable, alpacadecimal.Zero, s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal](), ledger.TransactionAuthorizationStatusAuthorized)) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.NewFromInt(-20)), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.NewFromInt(-20)), "the uncovered credit_only shortfall should live in the exact open advance receivable route", ) s.True( - s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, nil).Equal(alpacadecimal.NewFromInt(20)), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.NewFromInt(20)), "the uncovered shortfall should also remain in unattributed accrued until a later purchase backfills it", ) - assertDelta("accrued after credit-only advance", flatFeeStart.accrued, alpacadecimal.NewFromInt(100), s.mustCustomerAccruedBalance(cust.GetID(), USD)) - assertDelta("total wash after credit-only advance", flatFeeStart.totalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, nil)) - assertDelta("external wash after credit-only advance", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, &externalCostBasis)) + assertDelta("accrued after credit-only advance", flatFeeStart.accrued, alpacadecimal.NewFromInt(100), s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("total wash after credit-only advance", flatFeeStart.totalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.None[*alpacadecimal.Decimal]())) + assertDelta("external wash after credit-only advance", flatFeeStart.externalWash, alpacadecimal.Zero, s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis))) assertDelta("earnings after credit-only advance", flatFeeStart.earnings, alpacadecimal.Zero, s.mustEarningsBalance(ns, USD)) }) @@ -723,14 +1006,14 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { } start := backfillSnapshot{ - externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis), - externalOpenReceivable: s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, &externalCostBasis, ledger.TransactionAuthorizationStatusOpen), - advanceOpenReceivable: s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusOpen), - advanceAuthorized: s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusAuthorized), - externalAccrued: s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, &externalCostBasis), - unattributedAccrued: s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, nil), - totalAccrued: s.mustCustomerAccruedBalance(cust.GetID(), USD), - externalWash: s.mustWashBalance(ns, USD, &externalCostBasis), + externalFBO: s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)), + externalOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen), + advanceOpenReceivable: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen), + advanceAuthorized: s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusAuthorized), + externalAccrued: s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)), + unattributedAccrued: s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)), + totalAccrued: s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]()), + externalWash: s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis)), } const laterPurchaseAmount = 50 @@ -770,21 +1053,21 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { // - the prior advance receivable is re-attributed into the purchased cost-basis bucket // - unattributed accrued is translated into the purchased cost-basis bucket // - only the remainder becomes newly issued purchased credit - assertDelta("external FBO after later purchase initiation", start.externalFBO, alpacadecimal.NewFromInt(30), s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis)) + assertDelta("external FBO after later purchase initiation", start.externalFBO, alpacadecimal.NewFromInt(30), s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis))) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, &externalCostBasis, ledger.TransactionAuthorizationStatusOpen).Equal(start.externalOpenReceivable.Sub(alpacadecimal.NewFromInt(50))), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen).Equal(start.externalOpenReceivable.Sub(alpacadecimal.NewFromInt(50))), "the purchased cost-basis open receivable should now represent the full purchase amount", ) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), "the prior advance receivable should be fully re-attributed out of the nil cost-basis bucket at initiation", ) s.True( - s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, nil).Equal(alpacadecimal.Zero), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero), "the unattributed accrued bucket should be translated immediately during attribution", ) s.True( - s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, &externalCostBasis).Equal(start.externalAccrued.Add(alpacadecimal.NewFromInt(20))), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)).Equal(start.externalAccrued.Add(alpacadecimal.NewFromInt(20))), "the backfilled portion should already be visible in the purchased cost-basis accrued bucket after initiation", ) @@ -797,11 +1080,11 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { // Authorization now only stages settlement funding; attribution already happened during purchase initiation. s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, &externalCostBasis, ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.NewFromInt(50)), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.NewFromInt(50)), "the purchased amount should be visible in the exact authorized receivable route before settlement", ) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusAuthorized).Equal(start.advanceAuthorized), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusAuthorized).Equal(start.advanceAuthorized), "the legacy advance route should still have no authorized staging", ) @@ -815,38 +1098,38 @@ func (s *CreditsTestSuite) TestFlatFeeCreditOnlySanity() { // Settlement is now just the normal authorized -> open move in the purchased cost-basis bucket. // The earlier attribution stays intact, and the purchased receivable fully nets out here. s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), "the exact open advance receivable bucket should stay cleared after initiation-time attribution", ) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, nil, ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.Zero), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil), ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.Zero), "the exact authorized advance bucket should stay empty", ) s.True( - s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, nil).Equal(alpacadecimal.Zero), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some[*alpacadecimal.Decimal](nil)).Equal(alpacadecimal.Zero), "the unattributed accrued bucket should remain empty after initiation-time translation", ) s.True( - s.mustCustomerAccruedBalanceWithCostBasis(cust.GetID(), USD, &externalCostBasis).Equal(start.externalAccrued.Add(alpacadecimal.NewFromInt(20))), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)).Equal(start.externalAccrued.Add(alpacadecimal.NewFromInt(20))), "the backfilled portion should remain attributed in the purchased cost-basis bucket", ) s.True( - s.mustCustomerFBOBalance(cust.GetID(), USD, &externalCostBasis).Equal(start.externalFBO.Add(alpacadecimal.NewFromInt(30))), + s.mustCustomerFBOBalance(cust.GetID(), USD, mo.Some(&externalCostBasis)).Equal(start.externalFBO.Add(alpacadecimal.NewFromInt(30))), "only the purchase remainder should stay behind as newly available credit", ) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, &externalCostBasis, ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusOpen).Equal(alpacadecimal.Zero), "the purchased cost-basis receivable should net back to zero after settlement and advance funding", ) s.True( - s.mustCustomerReceivableRouteBalance(cust.GetID(), USD, &externalCostBasis, ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.Zero), + s.mustCustomerReceivableBalance(cust.GetID(), USD, mo.Some(&externalCostBasis), ledger.TransactionAuthorizationStatusAuthorized).Equal(alpacadecimal.Zero), "the purchased authorized receivable route should be cleared by settlement", ) s.True( - s.mustCustomerAccruedBalance(cust.GetID(), USD).Equal(start.totalAccrued), + s.mustCustomerAccruedBalance(cust.GetID(), USD, mo.None[*alpacadecimal.Decimal]()).Equal(start.totalAccrued), "settlement should only translate accrued between buckets, not change the total accrued amount", ) - assertDelta("external wash after later purchase settlement", start.externalWash, alpacadecimal.NewFromInt(-50), s.mustWashBalance(ns, USD, &externalCostBasis)) + assertDelta("external wash after later purchase settlement", start.externalWash, alpacadecimal.NewFromInt(-50), s.mustWashBalance(ns, USD, mo.Some(&externalCostBasis))) }) } @@ -954,30 +1237,37 @@ func (s *CreditsTestSuite) createLedgerBackedCustomer(ns string, subjectKey stri return cust } -func (s *CreditsTestSuite) mustCustomerFBOBalance(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal) alpacadecimal.Decimal { +// Use this helper for customer FBO balance in a currency. Pass mo.None() for +// all cost bases, mo.Some(nil) for the explicit nil-cost-basis route, or +// mo.Some(&costBasis) for one concrete cost-basis route. +func (s *CreditsTestSuite) mustCustomerFBOBalance(customerID customer.CustomerID, code currencyx.Code, costBasis mo.Option[*alpacadecimal.Decimal]) alpacadecimal.Decimal { return s.mustCustomerFBOBalanceWithPriority(customerID, code, costBasis, ledger.DefaultCustomerFBOPriority) } -func (s *CreditsTestSuite) mustCustomerFBOBalanceWithPriority(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal, priority int) alpacadecimal.Decimal { +// Use this helper for customer FBO balance in a currency when the test also +// needs to filter by a specific credit priority. Pass mo.None() for all cost +// bases, mo.Some(nil) for the explicit nil-cost-basis route, or +// mo.Some(&costBasis) for one concrete cost-basis route. +func (s *CreditsTestSuite) mustCustomerFBOBalanceWithPriority(customerID customer.CustomerID, code currencyx.Code, costBasis mo.Option[*alpacadecimal.Decimal], priority int) alpacadecimal.Decimal { s.T().Helper() customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) s.NoError(err) - subAccount, err := customerAccounts.FBOAccount.GetSubAccountForRoute(s.T().Context(), ledger.CustomerFBORouteParams{ + balance, err := customerAccounts.FBOAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, CostBasis: costBasis, - CreditPriority: priority, + CreditPriority: lo.ToPtr(priority), }) s.NoError(err) - balance, err := subAccount.GetBalance(s.T().Context()) - s.NoError(err) - return balance.Settled() } -func (s *CreditsTestSuite) mustCustomerReceivableBalance(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal) alpacadecimal.Decimal { +// Use this helper for customer receivable balance in a currency and one +// authorization state. Pass mo.None() for all cost bases, mo.Some(nil) for the +// explicit nil-cost-basis route, or mo.Some(&costBasis) for one concrete route. +func (s *CreditsTestSuite) mustCustomerReceivableBalance(customerID customer.CustomerID, code currencyx.Code, costBasis mo.Option[*alpacadecimal.Decimal], status ledger.TransactionAuthorizationStatus) alpacadecimal.Decimal { s.T().Helper() customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) @@ -985,82 +1275,36 @@ func (s *CreditsTestSuite) mustCustomerReceivableBalance(customerID customer.Cus balance, err := customerAccounts.ReceivableAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, - CostBasis: routeFilterCostBasis(costBasis), - TransactionAuthorizationStatus: lo.ToPtr(ledger.TransactionAuthorizationStatusOpen), - }) - s.NoError(err) - - return balance.Settled() -} - -func (s *CreditsTestSuite) mustCustomerAuthorizedReceivableBalance(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal) alpacadecimal.Decimal { - s.T().Helper() - - customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) - s.NoError(err) - - balance, err := customerAccounts.ReceivableAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ - Currency: code, - CostBasis: routeFilterCostBasis(costBasis), - TransactionAuthorizationStatus: lo.ToPtr(ledger.TransactionAuthorizationStatusAuthorized), + CostBasis: costBasis, + TransactionAuthorizationStatus: lo.ToPtr(status), }) s.NoError(err) return balance.Settled() } -func (s *CreditsTestSuite) mustCustomerAccruedBalance(customerID customer.CustomerID, code currencyx.Code) alpacadecimal.Decimal { +// Use this helper for customer accrued balance in a currency. Pass mo.None() for +// all cost bases, mo.Some(nil) for the explicit nil-cost-basis route, or +// mo.Some(&costBasis) for one concrete cost-basis route. +func (s *CreditsTestSuite) mustCustomerAccruedBalance(customerID customer.CustomerID, code currencyx.Code, costBasis mo.Option[*alpacadecimal.Decimal]) alpacadecimal.Decimal { s.T().Helper() customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) s.NoError(err) balance, err := customerAccounts.AccruedAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ - Currency: code, - }) - s.NoError(err) - - return balance.Settled() -} - -func (s *CreditsTestSuite) mustCustomerAccruedBalanceWithCostBasis(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal) alpacadecimal.Decimal { - s.T().Helper() - - customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) - s.NoError(err) - - subAccount, err := customerAccounts.AccruedAccount.GetSubAccountForRoute(s.T().Context(), ledger.CustomerAccruedRouteParams{ Currency: code, CostBasis: costBasis, }) s.NoError(err) - balance, err := subAccount.GetBalance(s.T().Context()) - s.NoError(err) - - return balance.Settled() -} - -func (s *CreditsTestSuite) mustCustomerReceivableRouteBalance(customerID customer.CustomerID, code currencyx.Code, costBasis *alpacadecimal.Decimal, status ledger.TransactionAuthorizationStatus) alpacadecimal.Decimal { - s.T().Helper() - - customerAccounts, err := s.LedgerResolver.GetCustomerAccounts(s.T().Context(), customerID) - s.NoError(err) - - subAccount, err := customerAccounts.ReceivableAccount.GetSubAccountForRoute(s.T().Context(), ledger.CustomerReceivableRouteParams{ - Currency: code, - CostBasis: costBasis, - TransactionAuthorizationStatus: status, - }) - s.NoError(err) - - balance, err := subAccount.GetBalance(s.T().Context()) - s.NoError(err) - return balance.Settled() } -func (s *CreditsTestSuite) mustWashBalance(namespace string, code currencyx.Code, costBasis *alpacadecimal.Decimal) alpacadecimal.Decimal { +// Use this helper for aggregate wash balance in a currency. Pass mo.None() for +// all cost bases, mo.Some(nil) for the explicit nil-cost-basis route, or +// mo.Some(&costBasis) for one concrete cost-basis route. +func (s *CreditsTestSuite) mustWashBalance(namespace string, code currencyx.Code, costBasis mo.Option[*alpacadecimal.Decimal]) alpacadecimal.Decimal { s.T().Helper() businessAccounts, err := s.LedgerResolver.GetBusinessAccounts(s.T().Context(), namespace) @@ -1068,7 +1312,7 @@ func (s *CreditsTestSuite) mustWashBalance(namespace string, code currencyx.Code balance, err := businessAccounts.WashAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, - CostBasis: routeFilterCostBasis(costBasis), + CostBasis: costBasis, }) s.NoError(err) @@ -1099,14 +1343,6 @@ func (s *CreditsTestSuite) mustGetChargeByID(chargeID meta.ChargeID) charges.Cha return charge } -func routeFilterCostBasis(costBasis *alpacadecimal.Decimal) mo.Option[*alpacadecimal.Decimal] { - if costBasis == nil { - return mo.None[*alpacadecimal.Decimal]() - } - - return mo.Some(costBasis) -} - type createCreditPurchaseIntentInput struct { customer customer.CustomerID currency currencyx.Code diff --git a/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.down.sql b/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.down.sql new file mode 100644 index 0000000000..34c487e4f8 --- /dev/null +++ b/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.down.sql @@ -0,0 +1,20 @@ +-- reverse: create index "creditrealizationlineagesegment_lineage_id_closed_at" to table: "credit_realization_lineage_segments" +DROP INDEX "creditrealizationlineagesegment_lineage_id_closed_at"; +-- reverse: create index "creditrealizationlineagesegment_lineage_id" to table: "credit_realization_lineage_segments" +DROP INDEX "creditrealizationlineagesegment_lineage_id"; +-- reverse: create index "creditrealizationlineagesegment_id" to table: "credit_realization_lineage_segments" +DROP INDEX "creditrealizationlineagesegment_id"; +-- reverse: create "credit_realization_lineage_segments" table +DROP TABLE "credit_realization_lineage_segments"; +-- reverse: create index "creditrealizationlineage_namespace_charge_id" to table: "credit_realization_lineages" +DROP INDEX "creditrealizationlineage_namespace_charge_id"; +-- reverse: create index "creditrealizationlineage_namespace_root_realization_id" to table: "credit_realization_lineages" +DROP INDEX "creditrealizationlineage_namespace_root_realization_id"; +-- reverse: create index "creditrealizationlineage_namespace_customer_id" to table: "credit_realization_lineages" +DROP INDEX "creditrealizationlineage_namespace_customer_id"; +-- reverse: create index "creditrealizationlineage_namespace" to table: "credit_realization_lineages" +DROP INDEX "creditrealizationlineage_namespace"; +-- reverse: create index "creditrealizationlineage_id" to table: "credit_realization_lineages" +DROP INDEX "creditrealizationlineage_id"; +-- reverse: create "credit_realization_lineages" table +DROP TABLE "credit_realization_lineages"; diff --git a/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.up.sql b/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.up.sql new file mode 100644 index 0000000000..35d2118e84 --- /dev/null +++ b/tools/migrate/migrations/20260409130630_add_credit_realization_lineage.up.sql @@ -0,0 +1,41 @@ +-- create "credit_realization_lineages" table +CREATE TABLE "credit_realization_lineages" ( + "id" character(26) NOT NULL, + "namespace" character varying NOT NULL, + "charge_id" character(26) NOT NULL, + "root_realization_id" character(26) NOT NULL, + "customer_id" character(26) NOT NULL, + "currency" character varying(3) NOT NULL, + "origin_kind" character varying NOT NULL, + "created_at" timestamptz NOT NULL, + PRIMARY KEY ("id"), + CONSTRAINT "credit_realization_lineages_charges_credit_realization_lineages" FOREIGN KEY ("charge_id") REFERENCES "charges" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION +); +-- create index "creditrealizationlineage_id" to table: "credit_realization_lineages" +CREATE UNIQUE INDEX "creditrealizationlineage_id" ON "credit_realization_lineages" ("id"); +-- create index "creditrealizationlineage_namespace" to table: "credit_realization_lineages" +CREATE INDEX "creditrealizationlineage_namespace" ON "credit_realization_lineages" ("namespace"); +-- create index "creditrealizationlineage_namespace_charge_id" to table: "credit_realization_lineages" +CREATE INDEX "creditrealizationlineage_namespace_charge_id" ON "credit_realization_lineages" ("namespace", "charge_id"); +-- create index "creditrealizationlineage_namespace_customer_id" to table: "credit_realization_lineages" +CREATE INDEX "creditrealizationlineage_namespace_customer_id" ON "credit_realization_lineages" ("namespace", "customer_id"); +-- create index "creditrealizationlineage_namespace_root_realization_id" to table: "credit_realization_lineages" +CREATE UNIQUE INDEX "creditrealizationlineage_namespace_root_realization_id" ON "credit_realization_lineages" ("namespace", "root_realization_id"); +-- create "credit_realization_lineage_segments" table +CREATE TABLE "credit_realization_lineage_segments" ( + "id" character(26) NOT NULL, + "amount" numeric NOT NULL, + "state" character varying NOT NULL, + "backing_transaction_group_id" character(26) NULL, + "closed_at" timestamptz NULL, + "created_at" timestamptz NOT NULL, + "lineage_id" character(26) NOT NULL, + PRIMARY KEY ("id"), + CONSTRAINT "credit_realization_lineage_segments_credit_realization_lineages" FOREIGN KEY ("lineage_id") REFERENCES "credit_realization_lineages" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION +); +-- create index "creditrealizationlineagesegment_id" to table: "credit_realization_lineage_segments" +CREATE UNIQUE INDEX "creditrealizationlineagesegment_id" ON "credit_realization_lineage_segments" ("id"); +-- create index "creditrealizationlineagesegment_lineage_id" to table: "credit_realization_lineage_segments" +CREATE INDEX "creditrealizationlineagesegment_lineage_id" ON "credit_realization_lineage_segments" ("lineage_id"); +-- create index "creditrealizationlineagesegment_lineage_id_closed_at" to table: "credit_realization_lineage_segments" +CREATE INDEX "creditrealizationlineagesegment_lineage_id_closed_at" ON "credit_realization_lineage_segments" ("lineage_id", "closed_at"); diff --git a/tools/migrate/migrations/atlas.sum b/tools/migrate/migrations/atlas.sum index 4ea78b4e5f..fe20b6b214 100644 --- a/tools/migrate/migrations/atlas.sum +++ b/tools/migrate/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:OjjKKkdVUZy4I4+hrQ61Kkz8r8QxjwFTblYPM93xF9M= +h1:v95ktNOwDMm7gmAmKGu3JqGMH/roQgLPgDfy6eWO8JI= 20240826120919_init.up.sql h1:tc1V91/smlmaeJGQ8h+MzTEeFjjnrrFDbDAjOYJK91o= 20240903155435_entitlement-expired-index.up.sql h1:Hp8u5uckmLXc1cRvWU0AtVnnK8ShlpzZNp8pbiJLhac= 20240917172257_billing-entities.up.sql h1:Q1dAMo0Vjiit76OybClNfYPGC5nmvov2/M2W1ioi4Kw= @@ -177,3 +177,4 @@ h1:OjjKKkdVUZy4I4+hrQ61Kkz8r8QxjwFTblYPM93xF9M= 20260403090155_charges_feature_ids.up.sql h1:fc/bLQ1teUmFSIsB2547wYeV5ge8hQwFdWh1poxMx60= 20260408082246_add_annotations_to_taxcode.up.sql h1:OFVNAijukvUz4nfnoxzbkL2F0myiNFNZchHBkNKWDkU= 20260409112434_invoice-line-backend.up.sql h1:IAB7hvWq+fw30Y6fUAFrT0EndZ/k4UhxVzq6OB/nt4o= +20260409130630_add_credit_realization_lineage.up.sql h1:6dB7wcaBcJEilSoVByeWxQBBLf7XVBAT7BQ7My/pXVc=