diff --git a/api/v3/handlers/customers/credits/get_balance.go b/api/v3/handlers/customers/credits/get_balance.go new file mode 100644 index 0000000000..7f8429f065 --- /dev/null +++ b/api/v3/handlers/customers/credits/get_balance.go @@ -0,0 +1,123 @@ +package customerscredits + +import ( + "context" + "errors" + "net/http" + + api "github.com/openmeterio/openmeter/api/v3" + "github.com/openmeterio/openmeter/api/v3/apierrors" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/commonhttp" + "github.com/openmeterio/openmeter/pkg/framework/transport/httptransport" + "github.com/openmeterio/openmeter/pkg/models" +) + +var errUnsupportedFeatureFilter = errors.New("feature filter is not supported for this balance endpoint") + +type ( + GetCustomerCreditBalanceRequest struct { + CustomerID customer.CustomerID + Currencies customerbalance.CurrencyFilter + } + GetCustomerCreditBalanceResponse = api.BillingCreditBalances + GetCustomerCreditBalanceParams struct { + CustomerID api.ULID + Params api.GetCustomerCreditBalanceParams + } + GetCustomerCreditBalanceHandler httptransport.HandlerWithArgs[GetCustomerCreditBalanceRequest, GetCustomerCreditBalanceResponse, GetCustomerCreditBalanceParams] +) + +func (h *handler) GetCustomerCreditBalance() GetCustomerCreditBalanceHandler { + return httptransport.NewHandlerWithArgs( + func(ctx context.Context, r *http.Request, args GetCustomerCreditBalanceParams) (GetCustomerCreditBalanceRequest, error) { + namespace, err := h.resolveNamespace(ctx) + if err != nil { + return GetCustomerCreditBalanceRequest{}, err + } + + if args.Params.Filter != nil && args.Params.Filter.Feature != nil { + return GetCustomerCreditBalanceRequest{}, apierrors.NewBadRequestError( + ctx, + models.NewGenericValidationError(errUnsupportedFeatureFilter), + apierrors.InvalidParameters{ + { + Field: "filter.feature", + Reason: errUnsupportedFeatureFilter.Error(), + Source: apierrors.InvalidParamSourceQuery, + }, + }, + ) + } + + request := GetCustomerCreditBalanceRequest{ + CustomerID: customer.CustomerID{ + Namespace: namespace, + ID: args.CustomerID, + }, + } + + if args.Params.Filter != nil && args.Params.Filter.Currency != nil { + currency := currencyx.Code(*args.Params.Filter.Currency) + request.Currencies = customerbalance.CurrencyFilter{ + Codes: []currencyx.Code{currency}, + } + } + + return request, nil + }, + func(ctx context.Context, request GetCustomerCreditBalanceRequest) (GetCustomerCreditBalanceResponse, error) { + _, err := h.customerService.GetCustomer(ctx, customer.GetCustomerInput{ + CustomerID: &request.CustomerID, + }) + if models.IsGenericNotFoundError(err) { + return GetCustomerCreditBalanceResponse{}, apierrors.NewNotFoundError(ctx, err, "customer") + } + if err != nil { + return GetCustomerCreditBalanceResponse{}, err + } + + balancesByCurrency, err := h.balanceFacade.GetBalances(ctx, customerbalance.GetBalancesInput{ + CustomerID: request.CustomerID, + Currencies: request.Currencies, + }) + if err != nil { + return GetCustomerCreditBalanceResponse{}, err + } + + balances := make([]api.CreditBalance, 0, len(balancesByCurrency)) + for _, item := range balancesByCurrency { + if len(request.Currencies.Codes) == 0 && item.Balance.Settled().IsZero() && item.Balance.Pending().IsZero() { + continue + } + + balances = append(balances, mapBalance(item.Currency, item.Balance)) + } + + return GetCustomerCreditBalanceResponse{ + RetrievedAt: clock.Now(), + Balances: balances, + }, nil + }, + commonhttp.JSONResponseEncoderWithStatus[GetCustomerCreditBalanceResponse](http.StatusOK), + httptransport.AppendOptions( + h.options, + httptransport.WithOperationName("get-customer-credit-balance"), + httptransport.WithErrorEncoder(apierrors.GenericErrorEncoder()), + )..., + ) +} + +func mapBalance(currency currencyx.Code, balance ledger.Balance) api.CreditBalance { + // Temporary mapping while the v3 credit-balance schema still predates the + // customerbalance service's settled/live-pending semantics. + return api.CreditBalance{ + Currency: api.BillingCurrencyCode(currency), + Available: balance.Settled().String(), + Pending: balance.Pending().String(), + } +} diff --git a/api/v3/handlers/customers/credits/handler.go b/api/v3/handlers/customers/credits/handler.go new file mode 100644 index 0000000000..74c380fc35 --- /dev/null +++ b/api/v3/handlers/customers/credits/handler.go @@ -0,0 +1,38 @@ +package customerscredits + +import ( + "context" + + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + "github.com/openmeterio/openmeter/pkg/framework/transport/httptransport" +) + +type customerBalanceFacade interface { + GetBalances(ctx context.Context, input customerbalance.GetBalancesInput) ([]customerbalance.BalanceByCurrency, error) +} + +type Handler interface { + GetCustomerCreditBalance() GetCustomerCreditBalanceHandler +} + +type handler struct { + resolveNamespace func(ctx context.Context) (string, error) + customerService customer.Service + balanceFacade customerBalanceFacade + options []httptransport.HandlerOption +} + +func New( + resolveNamespace func(ctx context.Context) (string, error), + customerService customer.Service, + balanceFacade customerBalanceFacade, + options ...httptransport.HandlerOption, +) Handler { + return &handler{ + resolveNamespace: resolveNamespace, + customerService: customerService, + balanceFacade: balanceFacade, + options: options, + } +} diff --git a/api/v3/server/routes.go b/api/v3/server/routes.go index 207222b559..bd1d042f31 100644 --- a/api/v3/server/routes.go +++ b/api/v3/server/routes.go @@ -5,6 +5,7 @@ import ( api "github.com/openmeterio/openmeter/api/v3" currencieshandler "github.com/openmeterio/openmeter/api/v3/handlers/currencies" + customerscreditshandler "github.com/openmeterio/openmeter/api/v3/handlers/customers/credits" ) // Meters @@ -322,7 +323,15 @@ func (s *Server) DeletePlanAddon(w http.ResponseWriter, r *http.Request, planId var unimplemented = api.Unimplemented{} func (s *Server) GetCustomerCreditBalance(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.GetCustomerCreditBalanceParams) { - unimplemented.GetCustomerCreditBalance(w, r, customerId, params) + if s.customersCreditsHandler == nil { + unimplemented.GetCustomerCreditBalance(w, r, customerId, params) + return + } + + s.customersCreditsHandler.GetCustomerCreditBalance().With(customerscreditshandler.GetCustomerCreditBalanceParams{ + CustomerID: customerId, + Params: params, + }).ServeHTTP(w, r) } func (s *Server) ListCreditGrants(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.ListCreditGrantsParams) { diff --git a/api/v3/server/server.go b/api/v3/server/server.go index 24b1e7f5a9..ee2481ba4a 100644 --- a/api/v3/server/server.go +++ b/api/v3/server/server.go @@ -19,6 +19,7 @@ import ( currencieshandler "github.com/openmeterio/openmeter/api/v3/handlers/currencies" customershandler "github.com/openmeterio/openmeter/api/v3/handlers/customers" customersbillinghandler "github.com/openmeterio/openmeter/api/v3/handlers/customers/billing" + customerscreditshandler "github.com/openmeterio/openmeter/api/v3/handlers/customers/credits" customersentitlementhandler "github.com/openmeterio/openmeter/api/v3/handlers/customers/entitlementaccess" eventshandler "github.com/openmeterio/openmeter/api/v3/handlers/events" featurecosthandler "github.com/openmeterio/openmeter/api/v3/handlers/featurecost" @@ -29,6 +30,7 @@ import ( taxcodeshandler "github.com/openmeterio/openmeter/api/v3/handlers/taxcodes" "github.com/openmeterio/openmeter/api/v3/oasmiddleware" "github.com/openmeterio/openmeter/api/v3/render" + "github.com/openmeterio/openmeter/app/config" "github.com/openmeterio/openmeter/openmeter/app" appstripe "github.com/openmeterio/openmeter/openmeter/app/stripe" "github.com/openmeterio/openmeter/openmeter/billing" @@ -37,6 +39,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/entitlement" "github.com/openmeterio/openmeter/openmeter/ingest" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/namespace/namespacedriver" @@ -57,6 +60,7 @@ type Config struct { ErrorHandler errorsx.Handler Middlewares []server.MiddlewareFunc PostAuthMiddlewares []server.MiddlewareFunc + Credits config.CreditsConfiguration // services AppService app.Service @@ -66,6 +70,7 @@ type Config struct { StreamingConnector streaming.Connector IngestService ingest.Service CustomerService customer.Service + CustomerBalanceFacade *customerbalance.Facade EntitlementService entitlement.Service PlanService plan.Service PlanSubscriptionService plansubscription.PlanSubscriptionService @@ -162,6 +167,7 @@ type Server struct { llmcostHandler llmcosthandler.Handler customersHandler customershandler.Handler customersBillingHandler customersbillinghandler.Handler + customersCreditsHandler customerscreditshandler.Handler customersEntitlementHandler customersentitlementhandler.Handler metersHandler metershandler.Handler subscriptionsHandler subscriptionshandler.Handler @@ -207,6 +213,10 @@ func NewServer(config *Config) (*Server, error) { eventsHandler := eventshandler.New(resolveNamespace, config.IngestService, httptransport.WithErrorHandler(config.ErrorHandler)) customersHandler := customershandler.New(resolveNamespace, config.CustomerService, httptransport.WithErrorHandler(config.ErrorHandler)) customersBillingHandler := customersbillinghandler.New(resolveNamespace, config.BillingService, config.CustomerService, config.StripeService, httptransport.WithErrorHandler(config.ErrorHandler)) + var customersCreditsHandler customerscreditshandler.Handler + if config.CustomerBalanceFacade != nil && config.Credits.Enabled { + customersCreditsHandler = customerscreditshandler.New(resolveNamespace, config.CustomerService, config.CustomerBalanceFacade, httptransport.WithErrorHandler(config.ErrorHandler)) + } customersEntitlementHandler := customersentitlementhandler.New(resolveNamespace, config.CustomerService, config.EntitlementService, httptransport.WithErrorHandler(config.ErrorHandler)) metersHandler := metershandler.New(resolveNamespace, config.MeterService, config.StreamingConnector, config.CustomerService, httptransport.WithErrorHandler(config.ErrorHandler)) subscriptionsHandler := subscriptionshandler.New(resolveNamespace, config.CustomerService, config.PlanService, config.PlanSubscriptionService, config.SubscriptionService, httptransport.WithErrorHandler(config.ErrorHandler)) @@ -234,6 +244,7 @@ func NewServer(config *Config) (*Server, error) { llmcostHandler: llmcostH, customersHandler: customersHandler, customersBillingHandler: customersBillingHandler, + customersCreditsHandler: customersCreditsHandler, customersEntitlementHandler: customersEntitlementHandler, metersHandler: metersHandler, subscriptionsHandler: subscriptionsHandler, diff --git a/app/common/customerbalance.go b/app/common/customerbalance.go new file mode 100644 index 0000000000..20ff580687 --- /dev/null +++ b/app/common/customerbalance.go @@ -0,0 +1,185 @@ +package common + +import ( + "context" + "log/slog" + + "github.com/google/wire" + + "github.com/openmeterio/openmeter/openmeter/billing/charges" + chargeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/adapter" + "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" + flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + usagebasedadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased/adapter" + usagebasedservice "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased/service" + "github.com/openmeterio/openmeter/openmeter/billing/rating" + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + ledgerchargeadapter "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" + "github.com/openmeterio/openmeter/openmeter/streaming" + "github.com/openmeterio/openmeter/pkg/framework/lockr" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +var CustomerBalance = wire.NewSet( + NewCustomerBalanceService, + NewCustomerBalanceFacade, +) + +func NewCustomerBalanceService( + logger *slog.Logger, + db *entdb.Client, + locker *lockr.Locker, + historicalLedger ledger.Ledger, + accountResolver ledger.AccountResolver, + accountService ledgeraccount.Service, + billingRegistry BillingRegistry, + featureConnector feature.FeatureConnector, + ratingService rating.Service, + streamingConnector streaming.Connector, +) (*customerbalance.Service, error) { + metaAdapter, err := metaadapter.New(metaadapter.Config{ + Client: db, + Logger: logger, + }) + if err != nil { + return nil, err + } + + searchAdapter, err := chargeadapter.New(chargeadapter.Config{ + Client: db, + Logger: logger, + }) + if err != nil { + return nil, err + } + + flatFeeAdapter, err := flatfeeadapter.New(flatfeeadapter.Config{ + Client: db, + Logger: logger, + MetaAdapter: metaAdapter, + }) + if err != nil { + return nil, err + } + + flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ + Adapter: flatFeeAdapter, + Handler: ledgerchargeadapter.NewFlatFeeHandler(historicalLedger, accountResolver, accountService), + MetaAdapter: metaAdapter, + Locker: locker, + }) + if err != nil { + return nil, err + } + + usageAdapter, err := usagebasedadapter.New(usagebasedadapter.Config{ + Client: db, + Logger: logger, + MetaAdapter: metaAdapter, + }) + if err != nil { + return nil, err + } + + usageService, err := usagebasedservice.New(usagebasedservice.Config{ + Adapter: usageAdapter, + Handler: usagebased.UnimplementedHandler{}, + Locker: locker, + MetaAdapter: metaAdapter, + CustomerOverrideService: billingRegistry.Billing, + FeatureService: featureConnector, + RatingService: ratingService, + StreamingConnector: streamingConnector, + }) + if err != nil { + return nil, err + } + + return customerbalance.New(customerbalance.Config{ + AccountResolver: accountResolver, + SubAccountService: accountService, + ChargesService: customerBalanceChargeStore{search: searchAdapter, flatFeeService: flatFeeService, usageBasedService: usageService}, + UsageBasedService: usageService, + }) +} + +func NewCustomerBalanceFacade(service *customerbalance.Service) (*customerbalance.Facade, error) { + return customerbalance.NewFacade(service) +} + +type customerBalanceChargeStore struct { + search charges.ChargesSearchAdapter + flatFeeService flatfee.Service + usageBasedService usagebased.Service +} + +func (s customerBalanceChargeStore) ListCharges(ctx context.Context, input charges.ListChargesInput) (pagination.Result[charges.Charge], error) { + searchResult, err := s.search.ListCharges(ctx, input) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + flatFeeIDs := make([]string, 0, len(searchResult.Items)) + usageBasedIDs := make([]string, 0, len(searchResult.Items)) + + for _, item := range searchResult.Items { + switch item.Type { + case meta.ChargeTypeFlatFee: + flatFeeIDs = append(flatFeeIDs, item.ID.ID) + case meta.ChargeTypeUsageBased: + usageBasedIDs = append(usageBasedIDs, item.ID.ID) + } + } + + flatFeeCharges, err := s.flatFeeService.GetByIDs(ctx, flatfee.GetByIDsInput{ + Namespace: input.Namespace, + IDs: flatFeeIDs, + Expands: input.Expands, + }) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + usageBasedCharges, err := s.usageBasedService.GetByIDs(ctx, usagebased.GetByIDsInput{ + Namespace: input.Namespace, + IDs: usageBasedIDs, + Expands: input.Expands, + }) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + chargesByID := make(map[string]charges.Charge, len(flatFeeCharges)+len(usageBasedCharges)) + + for _, charge := range flatFeeCharges { + chargesByID[charge.ID] = charges.NewCharge(charge) + } + + for _, charge := range usageBasedCharges { + chargesByID[charge.ID] = charges.NewCharge(charge) + } + + items := make([]charges.Charge, 0, len(searchResult.Items)) + for _, item := range searchResult.Items { + charge, ok := chargesByID[item.ID.ID] + if !ok { + continue + } + + items = append(items, charge) + } + + return pagination.Result[charges.Charge]{ + Page: searchResult.Page, + TotalCount: searchResult.TotalCount, + Items: items, + }, nil +} diff --git a/app/common/ledger.go b/app/common/ledger.go index b7cf984dbd..414ddaeeda 100644 --- a/app/common/ledger.go +++ b/app/common/ledger.go @@ -19,6 +19,7 @@ import ( // LedgerStack is the full provider set for the ledger stack. // Callers must provide *entdb.Client and *lockr.Locker (e.g. via common.Lockr). var LedgerStack = wire.NewSet( + NewLedgerRoutingValidator, NewLedgerAccountRepo, NewLedgerHistoricalRepo, NewLedgerResolversRepo, @@ -26,11 +27,14 @@ var LedgerStack = wire.NewSet( NewLedgerAccountService, NewLedgerHistoricalLedger, NewLedgerResolversService, - NewLedgerRoutingValidator, wire.Bind(new(ledger.Ledger), new(*historical.Ledger)), wire.Bind(new(ledger.AccountResolver), new(*resolvers.AccountResolver)), ) +func NewLedgerRoutingValidator() ledger.RoutingValidator { + return routingrules.DefaultValidator +} + func NewLedgerAccountRepo(db *entdb.Client) ledgeraccount.Repo { return accountadapter.NewRepo(db) } @@ -68,10 +72,6 @@ func NewLedgerHistoricalLedger( return historical.NewLedger(repo, accountSvc, locker, routingValidator) } -func NewLedgerRoutingValidator() ledger.RoutingValidator { - return routingrules.DefaultValidator -} - func NewLedgerResolversService( accountSvc ledgeraccount.Service, repo resolvers.CustomerAccountRepo, diff --git a/app/config/config.go b/app/config/config.go index 27a49cbf7d..0c9992418e 100644 --- a/app/config/config.go +++ b/app/config/config.go @@ -144,6 +144,10 @@ func (c Configuration) Validate() error { errs = append(errs, errorsx.WithPrefix(err, "billing")) } + if err := c.Credits.Validate(); err != nil { + errs = append(errs, errorsx.WithPrefix(err, "credits")) + } + if err := c.Apps.Validate(); err != nil { errs = append(errs, errorsx.WithPrefix(err, "apps")) } @@ -209,6 +213,7 @@ func SetViperDefaults(v *viper.Viper, flags *pflag.FlagSet) { ConfigureEvents(v) ConfigureBalanceWorker(v) ConfigureNotification(v) + ConfigureCredits(v) ConfigureBilling(v, flags) ConfigureProductCatalog(v) ConfigureApps(v, flags) diff --git a/app/config/config_test.go b/app/config/config_test.go index fb2fe0e746..bbea0e74aa 100644 --- a/app/config/config_test.go +++ b/app/config/config_test.go @@ -186,6 +186,9 @@ func TestComplete(t *testing.T) { }, }, }, + Credits: CreditsConfiguration{ + Enabled: false, + }, Sink: SinkConfiguration{ GroupId: "openmeter-sink-worker", MinCommitCount: 500, @@ -432,9 +435,6 @@ func TestComplete(t *testing.T) { `^_\..*$`, `^openmeter\..*$`, }, - Credits: CreditsConfiguration{ - Enabled: false, - }, } assert.Equal(t, expected, actual) diff --git a/cmd/server/main.go b/cmd/server/main.go index b97014448f..87b3bf97d7 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -148,9 +148,11 @@ func main() { Billing: app.BillingRegistry.Billing, BillingInvoicePendingLines: app.BillingRegistry.InvoicePendingLinesService(), BillingFeatureSwitches: conf.Billing.FeatureSwitches, + Credits: conf.Credits, CurrencyService: app.CurrencyService, CostService: app.CostService, Customer: app.Customer, + CustomerBalanceFacade: app.CustomerBalanceFacade, DebugConnector: debugConnector, ErrorHandler: errorsx.NewSlogHandler(logger), EntitlementBalanceConnector: app.EntitlementRegistry.MeteredEntitlement, @@ -165,7 +167,6 @@ func main() { MeterEventService: app.MeterEventService, NamespaceManager: app.NamespaceManager, Notification: app.Notification, - Credits: conf.Credits, Plan: app.Plan, PlanAddon: app.PlanAddon, PlanSubscriptionService: app.Subscription.PlanSubscriptionService, diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 5bf7c63afd..bb927ffcf6 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -20,6 +20,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ingest" "github.com/openmeterio/openmeter/openmeter/ingest/kafkaingest" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/meterevent" @@ -55,6 +56,7 @@ type Application struct { BillingRegistry common.BillingRegistry CurrencyService currencies.CurrencyService CostService cost.Service + CustomerBalanceFacade *customerbalance.Facade EntClient *db.Client EventPublisher eventbus.Publisher EntitlementRegistry *registry.Entitlement @@ -100,6 +102,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl common.Config, common.Currency, common.Customer, + common.CustomerBalance, common.NewCustomerSubjectServiceHook, common.NewCustomerEntitlementValidatorServiceHook, common.Database, @@ -109,6 +112,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl common.Kafka, common.KafkaIngest, common.LLMCost, + common.LedgerStack, common.KafkaNamespaceResolver, common.MeterManageWithConfigMeters, common.MeterEvent, @@ -127,7 +131,6 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl common.Server, common.TaxCode, common.Subscription, - common.LedgerStack, common.Lockr, common.Secret, common.ServerProvisionTopics, diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go index 0c47e78d20..6fd27f53f6 100644 --- a/cmd/server/wire_gen.go +++ b/cmd/server/wire_gen.go @@ -17,6 +17,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ingest" "github.com/openmeterio/openmeter/openmeter/ingest/kafkaingest" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/meterevent" @@ -460,6 +461,28 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl cleanup() return Application{}, nil, err } + customerbalanceService, err := common.NewCustomerBalanceService(logger, client, locker, ledger, accountResolver, accountService, billingRegistry, featureConnector, ratingService, connector) + if err != nil { + cleanup7() + cleanup6() + cleanup5() + cleanup4() + cleanup3() + cleanup2() + cleanup() + return Application{}, nil, err + } + facade, err := common.NewCustomerBalanceFacade(customerbalanceService) + if err != nil { + cleanup7() + cleanup6() + cleanup5() + cleanup4() + cleanup3() + cleanup2() + cleanup() + return Application{}, nil, err + } dedupeConfiguration := conf.Dedupe producer, err := common.NewKafkaProducer(kafkaIngestConfiguration, logger, commonMetadata) if err != nil { @@ -698,6 +721,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl BillingRegistry: billingRegistry, CurrencyService: currencyService, CostService: costService, + CustomerBalanceFacade: facade, EntClient: client, EventPublisher: eventbusPublisher, EntitlementRegistry: entitlement, @@ -761,6 +785,7 @@ type Application struct { BillingRegistry common.BillingRegistry CurrencyService currencies.CurrencyService CostService cost.Service + CustomerBalanceFacade *customerbalance.Facade EntClient *db.Client EventPublisher eventbus.Publisher EntitlementRegistry *registry.Entitlement diff --git a/config.example.yaml b/config.example.yaml index b7dbe50917..6679215e65 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -74,6 +74,9 @@ billing: # for production deployments it's recommended to use queued for server only # advancementStrategy: foreground +credits: + enabled: true + apps: baseURL: https://example.com # stripe: diff --git a/openmeter/billing/charges/charge.go b/openmeter/billing/charges/charge.go index 4a35b0689c..b111a889be 100644 --- a/openmeter/billing/charges/charge.go +++ b/openmeter/billing/charges/charge.go @@ -8,6 +8,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/pkg/framework/entutils" ) @@ -131,6 +132,25 @@ func (c Charge) GetChargeID() (meta.ChargeID, error) { return meta.ChargeID{}, fmt.Errorf("invalid charge type: %s", c.t) } +func (c Charge) SettlementMode() (productcatalog.SettlementMode, error) { + switch c.t { + case meta.ChargeTypeFlatFee: + if c.flatFee == nil { + return "", fmt.Errorf("flat fee charge is nil") + } + + return c.flatFee.Intent.SettlementMode, nil + case meta.ChargeTypeUsageBased: + if c.usageBased == nil { + return "", fmt.Errorf("usage based charge is nil") + } + + return c.usageBased.Intent.SettlementMode, nil + default: + return "", fmt.Errorf("settlement mode is not supported for charge type %s", c.t) + } +} + var _ entutils.InIDOrderAccessor = (*Charge)(nil) func (c Charge) GetID() string { diff --git a/openmeter/billing/charges/usagebased/service.go b/openmeter/billing/charges/usagebased/service.go index b0e5346b85..42ec6bcc81 100644 --- a/openmeter/billing/charges/usagebased/service.go +++ b/openmeter/billing/charges/usagebased/service.go @@ -5,9 +5,12 @@ import ( "errors" "fmt" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" + "github.com/openmeterio/openmeter/openmeter/billing/models/totals" "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" ) @@ -21,6 +24,7 @@ type UsageBasedService interface { GetByIDs(ctx context.Context, input GetByIDsInput) ([]Charge, error) AdvanceCharge(ctx context.Context, input AdvanceChargeInput) (*Charge, error) TriggerPatch(ctx context.Context, charge meta.ChargeID, patch meta.Patch) (*Charge, error) + GetCurrentTotals(ctx context.Context, input GetCurrentTotalsInput) (GetCurrentTotalsResult, error) } type InvoiceLifecycleHooks interface { @@ -119,3 +123,21 @@ func (i GetByIDInput) Validate() error { return errors.Join(errs...) } + +type GetCurrentTotalsInput struct { + ChargeID meta.ChargeID +} + +func (i GetCurrentTotalsInput) Validate() error { + if err := i.ChargeID.Validate(); err != nil { + return fmt.Errorf("charge ID: %w", err) + } + + return nil +} + +type GetCurrentTotalsResult struct { + Charge Charge + Quantity alpacadecimal.Decimal + DueTotals totals.Totals +} diff --git a/openmeter/billing/charges/usagebased/service/currenttotals.go b/openmeter/billing/charges/usagebased/service/currenttotals.go new file mode 100644 index 0000000000..5bee80863c --- /dev/null +++ b/openmeter/billing/charges/usagebased/service/currenttotals.go @@ -0,0 +1,65 @@ +package service + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/pkg/clock" +) + +func (s *service) GetCurrentTotals(ctx context.Context, input usagebased.GetCurrentTotalsInput) (usagebased.GetCurrentTotalsResult, error) { + if err := input.Validate(); err != nil { + return usagebased.GetCurrentTotalsResult{}, err + } + + charge, err := s.adapter.GetByID(ctx, usagebased.GetByIDInput{ + ChargeID: input.ChargeID, + Expands: meta.Expands{meta.ExpandRealizations}, + }) + if err != nil { + return usagebased.GetCurrentTotalsResult{}, fmt.Errorf("get charge: %w", err) + } + + customerOverride, err := s.customerOverrideService.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{ + Customer: customer.CustomerID{ + Namespace: charge.Namespace, + ID: charge.Intent.CustomerID, + }, + Expand: billing.CustomerOverrideExpand{ + Customer: true, + }, + }) + if err != nil { + return usagebased.GetCurrentTotalsResult{}, fmt.Errorf("get customer override: %w", err) + } + + featureMeters, err := s.featureService.ResolveFeatureMeters(ctx, charge.Namespace, []string{charge.Intent.FeatureKey}) + if err != nil { + return usagebased.GetCurrentTotalsResult{}, fmt.Errorf("resolve feature meters: %w", err) + } + + featureMeter, err := featureMeters.Get(charge.Intent.FeatureKey, true) + if err != nil { + return usagebased.GetCurrentTotalsResult{}, fmt.Errorf("get feature meter: %w", err) + } + + ratingResult, err := s.getRatingForUsage(ctx, getRatingForUsageInput{ + Charge: charge, + Customer: customerOverride, + FeatureMeter: featureMeter, + StoredAtOffset: clock.Now(), + }) + if err != nil { + return usagebased.GetCurrentTotalsResult{}, fmt.Errorf("get rating for usage: %w", err) + } + + return usagebased.GetCurrentTotalsResult{ + Charge: charge, + Quantity: ratingResult.Quantity, + DueTotals: ratingResult.Totals, + }, nil +} diff --git a/openmeter/ledger/customerbalance/calculation.go b/openmeter/ledger/customerbalance/calculation.go new file mode 100644 index 0000000000..961e8016e8 --- /dev/null +++ b/openmeter/ledger/customerbalance/calculation.go @@ -0,0 +1,107 @@ +package customerbalance + +import ( + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/productcatalog" +) + +type Impact struct { + charges.Charge + + amount alpacadecimal.Decimal +} + +func NewImpact(charge charges.Charge, amount alpacadecimal.Decimal) (Impact, error) { + if _, err := charge.SettlementMode(); err != nil { + return Impact{}, err + } + + return Impact{ + Charge: charge, + amount: amount, + }, nil +} + +func (i Impact) OutstandingAmount() alpacadecimal.Decimal { + amount := i.amount.Sub(i.RealizedCredits()) + if amount.IsNegative() { + return alpacadecimal.Zero + } + + return amount +} + +func (i Impact) RealizedCredits() alpacadecimal.Decimal { + switch i.Type() { + case meta.ChargeTypeFlatFee: + charge, _ := i.AsFlatFeeCharge() + return charge.State.CreditRealizations.Sum() + case meta.ChargeTypeUsageBased: + charge, _ := i.AsUsageBasedCharge() + total := alpacadecimal.Zero + + for _, run := range charge.Realizations { + total = total.Add(run.CreditsAllocated.Sum()) + } + + return total + default: + return alpacadecimal.Zero + } +} + +func (i Impact) BoundedAmount() alpacadecimal.Decimal { + settlementMode, err := i.SettlementMode() + if err != nil || settlementMode != productcatalog.CreditThenInvoiceSettlementMode { + return alpacadecimal.Zero + } + + return i.OutstandingAmount() +} + +func (i Impact) UnboundedAmount() alpacadecimal.Decimal { + settlementMode, err := i.SettlementMode() + if err != nil || settlementMode != productcatalog.CreditOnlySettlementMode { + return alpacadecimal.Zero + } + + return i.OutstandingAmount() +} + +type chargePendingBalanceCalculator struct{} + +func (chargePendingBalanceCalculator) CalculatePendingBalance(bookedBalance alpacadecimal.Decimal, impacts []Impact) alpacadecimal.Decimal { + boundedAmount, unboundedAmount := sumImpactAmounts(impacts) + + // credit_then_invoice can only consume positive balance, while credit_only can drive it negative. + pendingBalance := applyBoundedAmount(bookedBalance, boundedAmount) + + return pendingBalance.Sub(unboundedAmount) +} + +func sumImpactAmounts(impacts []Impact) (bounded alpacadecimal.Decimal, unbounded alpacadecimal.Decimal) { + bounded = alpacadecimal.Zero + unbounded = alpacadecimal.Zero + + for _, impact := range impacts { + bounded = bounded.Add(impact.BoundedAmount()) + unbounded = unbounded.Add(impact.UnboundedAmount()) + } + + return bounded, unbounded +} + +func applyBoundedAmount(balance alpacadecimal.Decimal, amount alpacadecimal.Decimal) alpacadecimal.Decimal { + if !balance.GreaterThan(alpacadecimal.Zero) { + return balance + } + + if amount.GreaterThan(balance) { + return alpacadecimal.Zero + } + + return balance.Sub(amount) +} diff --git a/openmeter/ledger/customerbalance/facade.go b/openmeter/ledger/customerbalance/facade.go new file mode 100644 index 0000000000..0f0c7b3854 --- /dev/null +++ b/openmeter/ledger/customerbalance/facade.go @@ -0,0 +1,128 @@ +package customerbalance + +import ( + "context" + "errors" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +type CurrencyFilter struct { + Codes []currencyx.Code +} + +func (f CurrencyFilter) Validate() error { + for _, code := range f.Codes { + if code == "" { + return errors.New("currency code is required") + } + } + + return nil +} + +type GetBalancesInput struct { + CustomerID customer.CustomerID + Currencies CurrencyFilter +} + +func (i GetBalancesInput) Validate() error { + var errs []error + + if err := i.CustomerID.Validate(); err != nil { + errs = append(errs, fmt.Errorf("customer ID: %w", err)) + } + + if err := i.Currencies.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currencies: %w", err)) + } + + return errors.Join(errs...) +} + +type BalanceByCurrency struct { + Currency currencyx.Code + Balance ledger.Balance +} + +type Facade struct { + service *Service +} + +func NewFacade(service *Service) (*Facade, error) { + if service == nil { + return nil, errors.New("service is required") + } + + return &Facade{ + service: service, + }, nil +} + +func (f *Facade) GetBalances(ctx context.Context, input GetBalancesInput) ([]BalanceByCurrency, error) { + if f == nil { + return nil, errors.New("facade is required") + } + + if err := input.Validate(); err != nil { + return nil, err + } + + var codes []currencyx.Code + if len(input.Currencies.Codes) > 0 { + codes = dedupeCurrencies(input.Currencies.Codes) + + for _, code := range codes { + if err := code.Validate(); err != nil { + return nil, fmt.Errorf("currency %q is not supported by ledger: %w", code, err) + } + } + } else { + var err error + + codes, err = f.service.getFBOCurrencies(ctx, input.CustomerID) + if err != nil { + return nil, fmt.Errorf("get FBO currencies: %w", err) + } + } + + balances := make([]BalanceByCurrency, 0, len(codes)) + for _, code := range codes { + balance, err := f.service.GetBalance(ctx, input.CustomerID, routeFilter(code)) + if err != nil { + return nil, err + } + + balances = append(balances, BalanceByCurrency{ + Currency: code, + Balance: balance, + }) + } + + return balances, nil +} + +func routeFilter(currency currencyx.Code) ledger.RouteFilter { + return ledger.RouteFilter{ + Currency: currency, + } +} + +func dedupeCurrencies(codes []currencyx.Code) []currencyx.Code { + seen := make(map[currencyx.Code]struct{}, len(codes)) + out := make([]currencyx.Code, 0, len(codes)) + + for _, code := range codes { + if _, ok := seen[code]; ok { + continue + } + + seen[code] = struct{}{} + out = append(out, code) + } + + return out +} diff --git a/openmeter/ledger/customerbalance/facade_test.go b/openmeter/ledger/customerbalance/facade_test.go new file mode 100644 index 0000000000..4834986dcd --- /dev/null +++ b/openmeter/ledger/customerbalance/facade_test.go @@ -0,0 +1,91 @@ +package customerbalance + +import ( + "testing" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +func TestFacadeGetBalancesWithExplicitCurrencies(t *testing.T) { + env := newTestEnv(t) + + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(100), "USD") + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(200), "EUR") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(30), productcatalog.CreditOnlySettlementMode, env.sp(), "USD") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(70), productcatalog.CreditOnlySettlementMode, env.sp(), "EUR") + + facade, err := NewFacade(env.Service) + require.NoError(t, err) + + balances, err := facade.GetBalances(t.Context(), GetBalancesInput{ + CustomerID: env.CustomerID, + Currencies: CurrencyFilter{ + Codes: []currencyx.Code{"USD", "EUR"}, + }, + }) + require.NoError(t, err) + require.Len(t, balances, 2) + + require.Equal(t, currencyx.Code("USD"), balances[0].Currency) + require.True(t, balances[0].Balance.Settled().Equal(alpacadecimal.NewFromInt(100))) + require.True(t, balances[0].Balance.Pending().Equal(alpacadecimal.NewFromInt(70))) + + require.Equal(t, currencyx.Code("EUR"), balances[1].Currency) + require.True(t, balances[1].Balance.Settled().Equal(alpacadecimal.NewFromInt(200))) + require.True(t, balances[1].Balance.Pending().Equal(alpacadecimal.NewFromInt(130))) +} + +func TestFacadeGetBalancesWithDiscoveredCurrencies(t *testing.T) { + env := newTestEnv(t) + + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(100), "USD") + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(200), "EUR") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(30), productcatalog.CreditOnlySettlementMode, env.sp(), "USD") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(70), productcatalog.CreditOnlySettlementMode, env.sp(), "EUR") + facade, err := NewFacade(env.Service) + require.NoError(t, err) + + balances, err := facade.GetBalances(t.Context(), GetBalancesInput{ + CustomerID: env.CustomerID, + }) + require.NoError(t, err) + require.Len(t, balances, 2) + + var usdCount, eurCount int + for _, balance := range balances { + switch balance.Currency { + case "USD": + usdCount++ + require.True(t, balance.Balance.Settled().Equal(alpacadecimal.NewFromInt(100))) + require.True(t, balance.Balance.Pending().Equal(alpacadecimal.NewFromInt(70))) + case "EUR": + eurCount++ + require.True(t, balance.Balance.Settled().Equal(alpacadecimal.NewFromInt(200))) + require.True(t, balance.Balance.Pending().Equal(alpacadecimal.NewFromInt(130))) + } + } + + require.Equal(t, 1, usdCount) + require.Equal(t, 1, eurCount) +} + +func TestFacadeGetBalancesWithUnsupportedExplicitCurrency(t *testing.T) { + env := newTestEnv(t) + + facade, err := NewFacade(env.Service) + require.NoError(t, err) + + _, err = facade.GetBalances(t.Context(), GetBalancesInput{ + CustomerID: env.CustomerID, + Currencies: CurrencyFilter{ + Codes: []currencyx.Code{"CUSTOM"}, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "CUSTOM") + require.ErrorContains(t, err, "not supported by ledger") +} diff --git a/openmeter/ledger/customerbalance/service.go b/openmeter/ledger/customerbalance/service.go new file mode 100644 index 0000000000..ccdce62119 --- /dev/null +++ b/openmeter/ledger/customerbalance/service.go @@ -0,0 +1,331 @@ +package customerbalance + +import ( + "context" + "errors" + "fmt" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +// ---------------------------------------------------------------------------- +// Dependency interfaces +// ---------------------------------------------------------------------------- + +type chargesService interface { + ListCharges(ctx context.Context, input charges.ListChargesInput) (pagination.Result[charges.Charge], error) +} + +type subAccountLister interface { + ListSubAccounts(ctx context.Context, input ledgeraccount.ListSubAccountsInput) ([]*ledgeraccount.SubAccount, error) +} + +type usageBasedTotalsService interface { + GetCurrentTotals(ctx context.Context, input usagebased.GetCurrentTotalsInput) (usagebased.GetCurrentTotalsResult, error) +} + +const chargeListPageSize = 100 + +// ---------------------------------------------------------------------------- +// Service +// ---------------------------------------------------------------------------- + +// service is NOT the RTE (Real Time Engine) +// - it is a simple service to bridge the gap until we get to implementing the RTE +// - this should be used for balance queries until the RTE is implemented +type Service struct { + AccountResolver ledger.AccountResolver + SubAccountService subAccountLister + ChargesService chargesService + UsageBasedService usageBasedTotalsService + + balanceCalculator chargePendingBalanceCalculator +} + +type Config struct { + AccountResolver ledger.AccountResolver + SubAccountService subAccountLister + ChargesService chargesService + UsageBasedService usageBasedTotalsService +} + +func (c Config) Validate() error { + var errs []error + + if c.AccountResolver == nil { + errs = append(errs, errors.New("account resolver is required")) + } + + if c.SubAccountService == nil { + errs = append(errs, errors.New("sub account service is required")) + } + + if c.ChargesService == nil { + errs = append(errs, errors.New("charges service is required")) + } + + if c.UsageBasedService == nil { + errs = append(errs, errors.New("usage based service is required")) + } + + return errors.Join(errs...) +} + +func New(config Config) (*Service, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &Service{ + AccountResolver: config.AccountResolver, + SubAccountService: config.SubAccountService, + ChargesService: config.ChargesService, + UsageBasedService: config.UsageBasedService, + balanceCalculator: chargePendingBalanceCalculator{}, + }, nil +} + +func (s *Service) GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter) (ledger.Balance, error) { + if err := s.validate(customerID, filters); err != nil { + return nil, err + } + + customerAccounts, err := s.AccountResolver.GetCustomerAccounts(ctx, customerID) + if err != nil { + return nil, fmt.Errorf("get customer accounts: %w", err) + } + + bookedBalance, err := customerAccounts.FBOAccount.GetBalance(ctx, filters) + if err != nil { + return nil, fmt.Errorf("get booked balance: %w", err) + } + + impacts, err := s.getChargePendingBalanceImpacts(ctx, customerID, filters.Currency) + if err != nil { + return nil, fmt.Errorf("get charge pending balance impacts: %w", err) + } + + return balance{ + settled: bookedBalance.Settled(), + pending: s.balanceCalculator.CalculatePendingBalance(bookedBalance.Pending(), impacts), + }, nil +} + +func (s *Service) getFBOCurrencies(ctx context.Context, customerID customer.CustomerID) ([]currencyx.Code, error) { + customerAccounts, err := s.AccountResolver.GetCustomerAccounts(ctx, customerID) + if err != nil { + return nil, fmt.Errorf("get customer accounts: %w", err) + } + + fboAccount, ok := customerAccounts.FBOAccount.(*ledgeraccount.CustomerFBOAccount) + if !ok { + return nil, fmt.Errorf("customer FBO account: unexpected type %T", customerAccounts.FBOAccount) + } + + subAccounts, err := s.SubAccountService.ListSubAccounts(ctx, ledgeraccount.ListSubAccountsInput{ + Namespace: fboAccount.ID().Namespace, + AccountID: fboAccount.ID().ID, + }) + if err != nil { + return nil, fmt.Errorf("list sub accounts: %w", err) + } + + seen := make(map[currencyx.Code]struct{}, len(subAccounts)) + codes := make([]currencyx.Code, 0, len(subAccounts)) + + for _, sa := range subAccounts { + c := sa.Route().Currency + if _, ok := seen[c]; ok { + continue + } + + seen[c] = struct{}{} + codes = append(codes, c) + } + + return codes, nil +} + +func (s *Service) getChargePendingBalanceImpacts(ctx context.Context, customerID customer.CustomerID, currency currencyx.Code) ([]Impact, error) { + items, err := pagination.CollectAll( + ctx, + pagination.NewPaginator(func(ctx context.Context, page pagination.Page) (pagination.Result[charges.Charge], error) { + return s.ChargesService.ListCharges(ctx, charges.ListChargesInput{ + Page: page, + Namespace: customerID.Namespace, + CustomerIDs: []string{customerID.ID}, + ChargeTypes: []meta.ChargeType{ + meta.ChargeTypeFlatFee, + meta.ChargeTypeUsageBased, + }, + StatusNotIn: []meta.ChargeStatus{meta.ChargeStatusFinal}, + Expands: meta.Expands{meta.ExpandRealizations}, + }) + }), + chargeListPageSize, + ) + if err != nil { + return nil, fmt.Errorf("list charges: %w", err) + } + + impacts := make([]Impact, 0, len(items)) + for _, charge := range items { + impact, err := s.getChargePendingBalanceImpact(ctx, charge, currency) + if err != nil { + return nil, err + } + + if impact == nil { + continue + } + + impacts = append(impacts, *impact) + } + + return impacts, nil +} + +func (s *Service) getChargePendingBalanceImpact(ctx context.Context, charge charges.Charge, currency currencyx.Code) (*Impact, error) { + if !chargeHasStarted(charge) { + return nil, nil + } + + switch charge.Type() { + case meta.ChargeTypeFlatFee: + return getFlatFeeChargePendingBalanceImpact(charge, currency) + case meta.ChargeTypeUsageBased: + return s.getUsageBasedChargePendingBalanceImpact(ctx, charge, currency) + default: + return nil, nil + } +} + +func getFlatFeeChargePendingBalanceImpact(charge charges.Charge, currency currencyx.Code) (*Impact, error) { + flatFeeCharge, err := charge.AsFlatFeeCharge() + if err != nil { + return nil, fmt.Errorf("map flat fee charge: %w", err) + } + + if flatFeeCharge.Intent.Currency != currency { + return nil, nil + } + + return newImpactOrNil(charge, flatFeeCharge.State.AmountAfterProration) +} + +func (s *Service) getUsageBasedChargePendingBalanceImpact(ctx context.Context, charge charges.Charge, currency currencyx.Code) (*Impact, error) { + usageBasedCharge, err := charge.AsUsageBasedCharge() + if err != nil { + return nil, fmt.Errorf("map usage based charge: %w", err) + } + + if usageBasedCharge.Intent.Currency != currency { + return nil, nil + } + + currentTotals, err := s.UsageBasedService.GetCurrentTotals(ctx, usagebased.GetCurrentTotalsInput{ + ChargeID: usageBasedCharge.GetChargeID(), + }) + if err != nil { + return nil, fmt.Errorf("get current totals for charge %s: %w", usageBasedCharge.ID, err) + } + + return newImpactOrNil(charges.NewCharge(currentTotals.Charge), currentTotals.DueTotals.Total) +} + +func chargeHasStarted(charge charges.Charge) bool { + now := clock.Now() + + switch charge.Type() { + case meta.ChargeTypeFlatFee: + flatFeeCharge, err := charge.AsFlatFeeCharge() + if err != nil { + return false + } + + return !now.Before(flatFeeCharge.Intent.ServicePeriod.From) + case meta.ChargeTypeUsageBased: + usageBasedCharge, err := charge.AsUsageBasedCharge() + if err != nil { + return false + } + + return !now.Before(usageBasedCharge.Intent.ServicePeriod.From) + default: + return false + } +} + +func newImpactOrNil(charge charges.Charge, amount alpacadecimal.Decimal) (*Impact, error) { + impact, err := NewImpact(charge, amount) + if err != nil { + return nil, err + } + + if impact.OutstandingAmount().IsZero() { + return nil, nil + } + + return &impact, nil +} + +func (s *Service) validate(customerID customer.CustomerID, filters ledger.RouteFilter) error { + var errs []error + + if s == nil { + errs = append(errs, errors.New("service is required")) + } else { + if s.AccountResolver == nil { + errs = append(errs, errors.New("account resolver is required")) + } + + if s.SubAccountService == nil { + errs = append(errs, errors.New("sub account service is required")) + } + + if s.ChargesService == nil { + errs = append(errs, errors.New("charges service is required")) + } + + if s.UsageBasedService == nil { + errs = append(errs, errors.New("usage based service is required")) + } + } + + if err := customerID.Validate(); err != nil { + errs = append(errs, fmt.Errorf("customer ID: %w", err)) + } + + if filters.Currency == "" { + errs = append(errs, errors.New("currency filter is required")) + } + + if _, err := filters.Normalize(); err != nil { + errs = append(errs, fmt.Errorf("route filter: %w", err)) + } + + return errors.Join(errs...) +} + +type balance struct { + settled alpacadecimal.Decimal + pending alpacadecimal.Decimal +} + +func (b balance) Settled() alpacadecimal.Decimal { + return b.settled +} + +func (b balance) Pending() alpacadecimal.Decimal { + return b.pending +} diff --git a/openmeter/ledger/customerbalance/service_test.go b/openmeter/ledger/customerbalance/service_test.go new file mode 100644 index 0000000000..625fecc2af --- /dev/null +++ b/openmeter/ledger/customerbalance/service_test.go @@ -0,0 +1,149 @@ +package customerbalance + +import ( + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +func TestGetBalance(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, env *testEnv) + wantSettled int64 + wantPending int64 + }{ + { + name: "flat fee credit only", + setup: func(t *testing.T, env *testEnv) { + env.bookFBOBalance(t, alpacadecimal.NewFromInt(100)) + env.createFlatFeeCharge(t, alpacadecimal.NewFromInt(30), productcatalog.CreditOnlySettlementMode, env.sp()) + }, + wantSettled: 100, + wantPending: 70, + }, + { + name: "flat fee credit then invoice", + setup: func(t *testing.T, env *testEnv) { + env.bookFBOBalance(t, alpacadecimal.NewFromInt(20)) + env.createFlatFeeCharge(t, alpacadecimal.NewFromInt(30), productcatalog.CreditThenInvoiceSettlementMode, env.sp()) + }, + wantSettled: 20, + wantPending: 0, + }, + { + name: "usage based credit only", + setup: func(t *testing.T, env *testEnv) { + env.addUsage(30, clock.Now().Add(-30*time.Minute)) + env.bookFBOBalance(t, alpacadecimal.NewFromInt(100)) + env.createUsageBasedCharge(t, alpacadecimal.NewFromInt(1), productcatalog.CreditOnlySettlementMode, env.sp()) + }, + wantSettled: 100, + wantPending: 70, + }, + { + name: "usage based credit then invoice", + setup: func(t *testing.T, env *testEnv) { + env.addUsage(30, clock.Now().Add(-30*time.Minute)) + env.bookFBOBalance(t, alpacadecimal.NewFromInt(20)) + env.createUsageBasedCharge(t, alpacadecimal.NewFromInt(1), productcatalog.CreditThenInvoiceSettlementMode, env.sp()) + }, + wantSettled: 20, + wantPending: 0, + }, + { + name: "mixed modes are pessimistic", + setup: func(t *testing.T, env *testEnv) { + env.addUsage(150, clock.Now().Add(-30*time.Minute)) + env.bookFBOBalance(t, alpacadecimal.NewFromInt(100)) + env.createFlatFeeCharge(t, alpacadecimal.NewFromInt(80), productcatalog.CreditThenInvoiceSettlementMode, env.sp()) + env.createUsageBasedCharge(t, alpacadecimal.NewFromInt(1), productcatalog.CreditOnlySettlementMode, env.sp()) + }, + wantSettled: 100, + wantPending: -130, + }, + { + name: "future charges are excluded until service period starts", + setup: func(t *testing.T, env *testEnv) { + futureServicePeriod := timeutil.ClosedPeriod{ + From: clock.Now().Add(time.Hour), + To: clock.Now().Add(2 * time.Hour), + } + + env.addUsage(30, clock.Now().Add(-30*time.Minute)) + env.bookFBOBalance(t, alpacadecimal.NewFromInt(100)) + env.createFlatFeeCharge(t, alpacadecimal.NewFromInt(30), productcatalog.CreditOnlySettlementMode, futureServicePeriod) + env.createUsageBasedCharge(t, alpacadecimal.NewFromInt(1), productcatalog.CreditOnlySettlementMode, futureServicePeriod) + }, + wantSettled: 100, + wantPending: 100, + }, + { + name: "already realized credits are not applied twice", + setup: func(t *testing.T, env *testEnv) { + env.bookFBOBalance(t, alpacadecimal.NewFromInt(70)) + + charge := env.createFlatFeeCharge(t, + alpacadecimal.NewFromInt(30), + productcatalog.CreditOnlySettlementMode, + env.sp(), + ) + + env.advanceFlatFeeCharge(t, charge) + }, + wantSettled: 70, + wantPending: 70, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + env := newTestEnv(t) + tt.setup(t, env) + + priority := ledger.DefaultCustomerFBOPriority + balance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ + Currency: env.Currency, + CreditPriority: &priority, + }) + require.NoError(t, err) + require.True(t, balance.Settled().Equal(alpacadecimal.NewFromInt(tt.wantSettled))) + require.True(t, balance.Pending().Equal(alpacadecimal.NewFromInt(tt.wantPending))) + }) + } +} + +func TestGetBalanceWithDifferentCurrency(t *testing.T) { + env := newTestEnv(t) + + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(100), "USD") + env.bookFBOBalanceInCurrency(t, alpacadecimal.NewFromInt(200), "EUR") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(30), productcatalog.CreditOnlySettlementMode, env.sp(), "USD") + env.createFlatFeeChargeInCurrency(t, alpacadecimal.NewFromInt(70), productcatalog.CreditOnlySettlementMode, env.sp(), "EUR") + + usdPriority := ledger.DefaultCustomerFBOPriority + usdBalance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ + Currency: currencyx.Code("USD"), + CreditPriority: &usdPriority, + }) + require.NoError(t, err) + require.True(t, usdBalance.Settled().Equal(alpacadecimal.NewFromInt(100))) + require.True(t, usdBalance.Pending().Equal(alpacadecimal.NewFromInt(70))) + + eurPriority := ledger.DefaultCustomerFBOPriority + eurBalance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ + Currency: currencyx.Code("EUR"), + CreditPriority: &eurPriority, + }) + require.NoError(t, err) + require.True(t, eurBalance.Settled().Equal(alpacadecimal.NewFromInt(200))) + require.True(t, eurBalance.Pending().Equal(alpacadecimal.NewFromInt(130))) +} diff --git a/openmeter/ledger/customerbalance/testenv_test.go b/openmeter/ledger/customerbalance/testenv_test.go new file mode 100644 index 0000000000..b516b9fd12 --- /dev/null +++ b/openmeter/ledger/customerbalance/testenv_test.go @@ -0,0 +1,431 @@ +package customerbalance + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/app" + "github.com/openmeterio/openmeter/openmeter/billing" + charges "github.com/openmeterio/openmeter/openmeter/billing/charges" + chargeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/adapter" + "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + flatfeeadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/adapter" + flatfeeservice "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee/service" + chargemeta "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + metaadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/meta/adapter" + chargestestutils "github.com/openmeterio/openmeter/openmeter/billing/charges/testutils" + "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + usagebasedadapter "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased/adapter" + usagebasedservice "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased/service" + billingratingservice "github.com/openmeterio/openmeter/openmeter/billing/rating/service" + "github.com/openmeterio/openmeter/openmeter/customer" + ledgertestutils "github.com/openmeterio/openmeter/openmeter/ledger/testutils" + "github.com/openmeterio/openmeter/openmeter/ledger/transactions" + "github.com/openmeterio/openmeter/openmeter/meter" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" + streamingtestutils "github.com/openmeterio/openmeter/openmeter/streaming/testutils" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/lockr" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +const ( + testFeatureKey = "api_requests" + testMeterKey = "api_requests" +) + +type testEnv struct { + *ledgertestutils.IntegrationEnv + + Service *Service + flatFeeService flatfee.Service + usageBasedService usagebased.Service + streaming *streamingtestutils.MockStreamingConnector +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + base := ledgertestutils.NewIntegrationEnv(t, "ledger-balance") + logger := slog.New(slog.DiscardHandler) + streaming := streamingtestutils.NewMockStreamingConnector(t) + handlers := chargestestutils.NewMockHandlers() + + billingService := mockCustomerOverrideService{ + customer: customer.Customer{ + ManagedResource: models.ManagedResource{ + NamespacedModel: models.NamespacedModel{ + Namespace: base.CustomerID.Namespace, + }, + ID: base.CustomerID.ID, + Name: "Test Customer", + }, + UsageAttribution: &customer.CustomerUsageAttribution{ + SubjectKeys: []string{"subject-1"}, + }, + }, + } + + featureService := mockFeatureConnector{ + meters: feature.FeatureMeterCollection{ + testFeatureKey: { + Feature: feature.Feature{ + Namespace: base.Namespace, + ID: "feature-1", + Name: "API Requests", + Key: testFeatureKey, + MeterID: lo.ToPtr("meter-1"), + CreatedAt: base.Now(), + UpdatedAt: base.Now(), + }, + Meter: &meter.Meter{ + ManagedResource: models.ManagedResource{ + NamespacedModel: models.NamespacedModel{ + Namespace: base.Namespace, + }, + ID: "meter-1", + Name: "API Requests Meter", + }, + Key: testMeterKey, + Aggregation: meter.MeterAggregationSum, + EventType: "api_request", + }, + }, + }, + } + + metaAdapter, err := metaadapter.New(metaadapter.Config{ + Client: base.DB, + Logger: logger, + }) + require.NoError(t, err) + + locker, err := lockr.NewLocker(&lockr.LockerConfig{ + Logger: logger, + }) + require.NoError(t, err) + + usageAdapter, err := usagebasedadapter.New(usagebasedadapter.Config{ + Client: base.DB, + Logger: logger, + MetaAdapter: metaAdapter, + }) + require.NoError(t, err) + + flatFeeAdapter, err := flatfeeadapter.New(flatfeeadapter.Config{ + Client: base.DB, + Logger: logger, + MetaAdapter: metaAdapter, + }) + require.NoError(t, err) + + flatFeeService, err := flatfeeservice.New(flatfeeservice.Config{ + Adapter: flatFeeAdapter, + Handler: handlers.FlatFee, + MetaAdapter: metaAdapter, + Locker: locker, + }) + require.NoError(t, err) + + usageService, err := usagebasedservice.New(usagebasedservice.Config{ + Adapter: usageAdapter, + Handler: handlers.UsageBased, + Locker: locker, + MetaAdapter: metaAdapter, + CustomerOverrideService: billingService, + FeatureService: featureService, + RatingService: billingratingservice.New(), + StreamingConnector: streaming, + }) + require.NoError(t, err) + + searchAdapter, err := chargeadapter.New(chargeadapter.Config{ + Client: base.DB, + Logger: logger, + }) + require.NoError(t, err) + + service, err := New(Config{ + AccountResolver: base.Deps.ResolversService, + SubAccountService: base.Deps.AccountService, + ChargesService: chargeStore{ + search: searchAdapter, + flatFeeService: flatFeeService, + usageBasedService: usageService, + }, + UsageBasedService: usageService, + }) + require.NoError(t, err) + + env := &testEnv{ + IntegrationEnv: base, + Service: service, + flatFeeService: flatFeeService, + usageBasedService: usageService, + streaming: streaming, + } + + env.createCustomer(t) + + return env +} + +func (e *testEnv) addUsage(value float64, at time.Time) { + e.streaming.AddSimpleEvent(testMeterKey, value, at) +} + +func (e *testEnv) sp() timeutil.ClosedPeriod { + return timeutil.ClosedPeriod{ + From: clock.Now().Add(-time.Hour), + To: clock.Now().Add(time.Hour), + } +} + +// simply currency based backing (balance doesn't care about most dimensions) +func (e *testEnv) bookFBOBalance(t *testing.T, amount alpacadecimal.Decimal) { + e.bookFBOBalanceInCurrency(t, amount, e.Currency) +} + +func (e *testEnv) bookFBOBalanceInCurrency(t *testing.T, amount alpacadecimal.Decimal, currency currencyx.Code) { + t.Helper() + + inputs, err := transactions.ResolveTransactions( + t.Context(), + transactions.ResolverDependencies{ + AccountService: e.Deps.ResolversService, + SubAccountService: e.Deps.AccountService, + }, + transactions.ResolutionScope{ + CustomerID: e.CustomerID, + Namespace: e.Namespace, + }, + transactions.IssueCustomerReceivableTemplate{ + At: e.Now(), + Amount: amount, + Currency: currency, + }, + ) + require.NoError(t, err) + + _, err = e.Deps.HistoricalLedger.CommitGroup(t.Context(), transactions.GroupInputs(e.Namespace, nil, inputs...)) + require.NoError(t, err) +} + +func (e *testEnv) createUsageBasedCharge(t *testing.T, unitPrice alpacadecimal.Decimal, settlementMode productcatalog.SettlementMode, servicePeriod timeutil.ClosedPeriod) usagebased.Charge { + return e.createUsageBasedChargeInCurrency(t, unitPrice, settlementMode, servicePeriod, e.Currency) +} + +func (e *testEnv) createUsageBasedChargeInCurrency(t *testing.T, unitPrice alpacadecimal.Decimal, settlementMode productcatalog.SettlementMode, servicePeriod timeutil.ClosedPeriod, currency currencyx.Code) usagebased.Charge { + t.Helper() + + createdCharges, err := e.usageBasedService.Create(t.Context(), usagebased.CreateInput{ + Namespace: e.Namespace, + Intents: []usagebased.Intent{ + { + Intent: chargemeta.Intent{ + Name: "API Requests", + ManagedBy: billing.SystemManagedLine, + CustomerID: e.CustomerID.ID, + Currency: currency, + ServicePeriod: servicePeriod, + FullServicePeriod: servicePeriod, + BillingPeriod: servicePeriod, + }, + InvoiceAt: e.Now().Add(-time.Minute), + SettlementMode: settlementMode, + FeatureKey: testFeatureKey, + Price: *productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: unitPrice}), + }, + }, + }) + require.NoError(t, err) + require.Len(t, createdCharges, 1) + + return createdCharges[0].Charge +} + +func (e *testEnv) createFlatFeeCharge(t *testing.T, amount alpacadecimal.Decimal, settlementMode productcatalog.SettlementMode, servicePeriod timeutil.ClosedPeriod) flatfee.Charge { + return e.createFlatFeeChargeInCurrency(t, amount, settlementMode, servicePeriod, e.Currency) +} + +func (e *testEnv) createFlatFeeChargeInCurrency(t *testing.T, amount alpacadecimal.Decimal, settlementMode productcatalog.SettlementMode, servicePeriod timeutil.ClosedPeriod, currency currencyx.Code) flatfee.Charge { + t.Helper() + + createdCharges, err := e.flatFeeService.Create(t.Context(), flatfee.CreateInput{ + Namespace: e.Namespace, + Intents: []flatfee.Intent{ + { + Intent: chargemeta.Intent{ + Name: "Platform Fee", + ManagedBy: billing.SystemManagedLine, + CustomerID: e.CustomerID.ID, + Currency: currency, + ServicePeriod: servicePeriod, + FullServicePeriod: servicePeriod, + BillingPeriod: servicePeriod, + }, + InvoiceAt: e.Now().Add(-time.Minute), + SettlementMode: settlementMode, + PaymentTerm: productcatalog.InAdvancePaymentTerm, + AmountBeforeProration: amount, + }, + }, + }) + require.NoError(t, err) + require.Len(t, createdCharges, 1) + + return createdCharges[0].Charge +} + +func (e *testEnv) advanceFlatFeeCharge(t *testing.T, charge flatfee.Charge) flatfee.Charge { + t.Helper() + + advancedCharge, err := e.flatFeeService.AdvanceCharge(t.Context(), flatfee.AdvanceChargeInput{ + ChargeID: charge.GetChargeID(), + }) + require.NoError(t, err) + require.NotNil(t, advancedCharge) + + return *advancedCharge +} + +func (e *testEnv) createCustomer(t *testing.T) { + t.Helper() + + _, err := e.DB.Customer.Create(). + SetNamespace(e.Namespace). + SetID(e.CustomerID.ID). + SetName("Test Customer"). + Save(t.Context()) + require.NoError(t, err) +} + +type chargeStore struct { + search charges.ChargesSearchAdapter + flatFeeService flatfee.Service + usageBasedService usagebased.Service +} + +func (l chargeStore) ListCharges(ctx context.Context, input charges.ListChargesInput) (pagination.Result[charges.Charge], error) { + searchResult, err := l.search.ListCharges(ctx, input) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + flatFeeIDs := make([]string, 0, len(searchResult.Items)) + usageBasedIDs := make([]string, 0, len(searchResult.Items)) + for _, item := range searchResult.Items { + switch item.Type { + case chargemeta.ChargeTypeFlatFee: + flatFeeIDs = append(flatFeeIDs, item.ID.ID) + case chargemeta.ChargeTypeUsageBased: + usageBasedIDs = append(usageBasedIDs, item.ID.ID) + } + } + + flatFeeCharges, err := l.flatFeeService.GetByIDs(ctx, flatfee.GetByIDsInput{ + Namespace: input.Namespace, + IDs: flatFeeIDs, + Expands: input.Expands, + }) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + usageBasedCharges, err := l.usageBasedService.GetByIDs(ctx, usagebased.GetByIDsInput{ + Namespace: input.Namespace, + IDs: usageBasedIDs, + Expands: input.Expands, + }) + if err != nil { + return pagination.Result[charges.Charge]{}, err + } + + chargesByID := make(map[string]charges.Charge, len(flatFeeCharges)+len(usageBasedCharges)) + for _, charge := range flatFeeCharges { + chargesByID[charge.ID] = charges.NewCharge(charge) + } + for _, charge := range usageBasedCharges { + chargesByID[charge.ID] = charges.NewCharge(charge) + } + + items := make([]charges.Charge, 0, len(chargesByID)) + for _, item := range searchResult.Items { + charge, ok := chargesByID[item.ID.ID] + if !ok { + continue + } + + items = append(items, charge) + } + + return pagination.Result[charges.Charge]{ + Page: searchResult.Page, + TotalCount: searchResult.TotalCount, + Items: items, + }, nil +} + +type mockCustomerOverrideService struct { + customer customer.Customer +} + +func (s mockCustomerOverrideService) UpsertCustomerOverride(context.Context, billing.UpsertCustomerOverrideInput) (billing.CustomerOverrideWithDetails, error) { + return billing.CustomerOverrideWithDetails{}, nil +} + +func (s mockCustomerOverrideService) DeleteCustomerOverride(context.Context, billing.DeleteCustomerOverrideInput) error { + return nil +} + +func (s mockCustomerOverrideService) GetCustomerOverride(context.Context, billing.GetCustomerOverrideInput) (billing.CustomerOverrideWithDetails, error) { + return billing.CustomerOverrideWithDetails{ + Customer: &s.customer, + }, nil +} + +func (s mockCustomerOverrideService) GetCustomerApp(context.Context, billing.GetCustomerAppInput) (app.App, error) { + return nil, nil +} + +func (s mockCustomerOverrideService) ListCustomerOverrides(context.Context, billing.ListCustomerOverridesInput) (billing.ListCustomerOverridesResult, error) { + return billing.ListCustomerOverridesResult{}, nil +} + +type mockFeatureConnector struct { + meters feature.FeatureMeterCollection +} + +func (c mockFeatureConnector) CreateFeature(context.Context, feature.CreateFeatureInputs) (feature.Feature, error) { + return feature.Feature{}, nil +} + +func (c mockFeatureConnector) UpdateFeature(context.Context, feature.UpdateFeatureInputs) (feature.Feature, error) { + return feature.Feature{}, nil +} + +func (c mockFeatureConnector) ArchiveFeature(context.Context, models.NamespacedID) error { + return nil +} + +func (c mockFeatureConnector) ListFeatures(context.Context, feature.ListFeaturesParams) (pagination.Result[feature.Feature], error) { + return pagination.Result[feature.Feature]{}, nil +} + +func (c mockFeatureConnector) GetFeature(context.Context, string, string, feature.IncludeArchivedFeature) (*feature.Feature, error) { + return nil, nil +} + +func (c mockFeatureConnector) ResolveFeatureMeters(context.Context, string, []string) (feature.FeatureMeters, error) { + return c.meters, nil +} diff --git a/openmeter/server/router/router.go b/openmeter/server/router/router.go index 8472cfced2..0acef65618 100644 --- a/openmeter/server/router/router.go +++ b/openmeter/server/router/router.go @@ -38,6 +38,7 @@ import ( infohttpdriver "github.com/openmeterio/openmeter/openmeter/info/httpdriver" "github.com/openmeterio/openmeter/openmeter/ingest" ingesthttpdriver "github.com/openmeterio/openmeter/openmeter/ingest/httpdriver" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" meterhttphandler "github.com/openmeterio/openmeter/openmeter/meter/httphandler" @@ -99,9 +100,11 @@ type Config struct { Billing billing.Service BillingInvoicePendingLines billing.InvoicePendingLinesService BillingFeatureSwitches config.BillingFeatureSwitchesConfiguration + Credits config.CreditsConfiguration CurrencyService currencies.CurrencyService CostService cost.Service Customer customer.Service + CustomerBalanceFacade *customerbalance.Facade DebugConnector debug.DebugConnector EntitlementConnector entitlement.Service EntitlementBalanceConnector meteredentitlement.Connector @@ -117,7 +120,6 @@ type Config struct { NamespaceManager *namespace.Manager Notification notification.Service Plan plan.Service - Credits config.CreditsConfiguration PlanAddon planaddon.Service PlanSubscriptionService plansubscription.PlanSubscriptionService PortalCORSEnabled bool diff --git a/openmeter/server/server.go b/openmeter/server/server.go index 1f238463e0..506ed8d047 100644 --- a/openmeter/server/server.go +++ b/openmeter/server/server.go @@ -109,9 +109,11 @@ func NewServer(config *Config) (*Server, error) { BaseURL: "/api/v3", NamespaceDecoder: namespacedriver.StaticNamespaceDecoder(config.RouterConfig.NamespaceManager.GetDefaultNamespace()), ErrorHandler: config.RouterConfig.ErrorHandler, + Credits: config.RouterConfig.Credits, AppService: config.RouterConfig.App, BillingService: config.RouterConfig.Billing, CustomerService: config.RouterConfig.Customer, + CustomerBalanceFacade: config.RouterConfig.CustomerBalanceFacade, CurrencyService: config.RouterConfig.CurrencyService, EntitlementService: config.RouterConfig.EntitlementConnector, IngestService: config.RouterConfig.IngestService,