diff --git a/openmeter/app/custominvoicing/httpdriver/custominvoicing.go b/openmeter/app/custominvoicing/httpdriver/custominvoicing.go index ea520521a7..b700354989 100644 --- a/openmeter/app/custominvoicing/httpdriver/custominvoicing.go +++ b/openmeter/app/custominvoicing/httpdriver/custominvoicing.go @@ -53,7 +53,7 @@ func (h *handler) DraftSyncronized() DraftSyncronizedHandler { return DraftSyncronizedResponse{}, err } - return billinghttpdriver.MapInvoiceToAPI(invoice) + return billinghttpdriver.MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[DraftSyncronizedResponse](http.StatusOK), httptransport.AppendOptions( @@ -104,7 +104,7 @@ func (h *handler) IssuingSyncronized() IssuingSyncronizedHandler { return IssuingSyncronizedResponse{}, err } - return billinghttpdriver.MapInvoiceToAPI(invoice) + return billinghttpdriver.MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[IssuingSyncronizedResponse](http.StatusOK), httptransport.AppendOptions( @@ -160,7 +160,7 @@ func (h *handler) UpdatePaymentStatus() UpdatePaymentStatusHandler { return UpdatePaymentStatusResponse{}, err } - return billinghttpdriver.MapInvoiceToAPI(invoice) + return billinghttpdriver.MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[UpdatePaymentStatusResponse](http.StatusOK), httptransport.AppendOptions( diff --git a/openmeter/app/stripe/service/billing.go b/openmeter/app/stripe/service/billing.go index d4999b9071..ddd277dfc8 100644 --- a/openmeter/app/stripe/service/billing.go +++ b/openmeter/app/stripe/service/billing.go @@ -149,7 +149,7 @@ func stripeErrorToValidationError(stripeErr *stripe.Error) error { // getInvoiceByStripeID retrieves an invoice by its stripe ID, it returns nil if the invoice is not found (thus not managed by the app) func (s *Service) getInvoiceByStripeID(ctx context.Context, appID app.AppID, stripeInvoiceID string) (*billing.StandardInvoice, error) { - invoices, err := s.billingService.ListInvoices(ctx, billing.ListInvoicesInput{ + invoices, err := s.billingService.ListStandardInvoices(ctx, billing.ListStandardInvoicesInput{ Namespaces: []string{appID.Namespace}, ExternalIDs: &billing.ListInvoicesExternalIDFilter{ Type: billing.InvoicingExternalIDType, diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 53fe596f3f..298180f5f5 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -66,20 +66,22 @@ type InvoiceLineAdapter interface { type InvoiceAdapter interface { CreateInvoice(ctx context.Context, input CreateInvoiceAdapterInput) (CreateInvoiceAdapterRespone, error) - GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (StandardInvoice, error) - DeleteGatheringInvoices(ctx context.Context, input DeleteGatheringInvoicesInput) error + GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (Invoice, error) ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error) UpdateInvoice(ctx context.Context, input UpdateInvoiceAdapterInput) (StandardInvoice, error) GetInvoiceOwnership(ctx context.Context, input GetInvoiceOwnershipAdapterInput) (GetOwnershipAdapterResponse, error) + + GetInvoiceType(ctx context.Context, input GetInvoiceTypeAdapterInput) (InvoiceType, error) } type GatheringInvoiceAdapter interface { CreateGatheringInvoice(ctx context.Context, input CreateGatheringInvoiceAdapterInput) (GatheringInvoice, error) UpdateGatheringInvoice(ctx context.Context, input UpdateGatheringInvoiceAdapterInput) error DeleteGatheringInvoice(ctx context.Context, input DeleteGatheringInvoiceAdapterInput) error - GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error) + // TODO: remove + // GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error) ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error) HardDeleteGatheringInvoiceLines(ctx context.Context, invoiceID InvoiceID, lineIDs []string) error diff --git a/openmeter/billing/adapter/gatheringinvoice.go b/openmeter/billing/adapter/gatheringinvoice.go index 5b43bf3eb5..cd154579c9 100644 --- a/openmeter/billing/adapter/gatheringinvoice.go +++ b/openmeter/billing/adapter/gatheringinvoice.go @@ -216,6 +216,14 @@ func (a *adapter) ListGatheringInvoices(ctx context.Context, input billing.ListG query = query.Where(billinginvoice.CurrencyIn(input.Currencies...)) } + if len(input.IDs) > 0 { + query = query.Where(billinginvoice.IDIn(input.IDs...)) + } + + if input.NextCollectionAtBeforeOrEqual != nil { + query = query.Where(billinginvoice.CollectionAtLTE(*input.NextCollectionAtBeforeOrEqual)) + } + order := entutils.GetOrdering(sortx.OrderDefault) if !input.Order.IsDefaultValue() { order = entutils.GetOrdering(input.Order) @@ -303,6 +311,7 @@ func (a *adapter) DeleteGatheringInvoice(ctx context.Context, input billing.Dele invoice, err := tx.db.BillingInvoice.Query(). Where(billinginvoice.ID(input.ID)). Where(billinginvoice.Namespace(input.Namespace)). + Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)). Only(ctx) if err != nil { return err @@ -322,6 +331,8 @@ func (a *adapter) DeleteGatheringInvoice(ctx context.Context, input billing.Dele Where(billinginvoice.ID(input.ID)). Where(billinginvoice.Namespace(input.Namespace)). SetDeletedAt(clock.Now()). + ClearPeriodStart(). + ClearPeriodEnd(). Save(ctx) if err != nil { return err @@ -351,7 +362,8 @@ func (a *adapter) GetGatheringInvoiceById(ctx context.Context, input billing.Get return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GatheringInvoice, error) { query := tx.db.BillingInvoice.Query(). Where(billinginvoice.ID(input.Invoice.ID)). - Where(billinginvoice.Namespace(input.Invoice.Namespace)) + Where(billinginvoice.Namespace(input.Invoice.Namespace)). + Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)) if input.Expand.Has(billing.GatheringInvoiceExpandLines) { query = a.expandGatheringInvoiceLines(query, input.Expand) diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index 51364dfba3..f66269ed7b 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -2,8 +2,8 @@ package billingadapter import ( "context" - "errors" "fmt" + "slices" "strings" "time" @@ -20,7 +20,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoicevalidationissue" "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" "github.com/openmeterio/openmeter/openmeter/streaming" - "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/convert" "github.com/openmeterio/openmeter/pkg/framework/entutils" "github.com/openmeterio/openmeter/pkg/models" @@ -30,7 +29,7 @@ import ( var _ billing.InvoiceAdapter = (*adapter)(nil) -func (a *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdInput) (billing.StandardInvoice, error) { +func (a *adapter) GetStandardInvoiceById(ctx context.Context, in billing.GetInvoiceByIdInput) (billing.StandardInvoice, error) { if err := in.Validate(); err != nil { return billing.StandardInvoice{}, billing.ValidationError{ Err: err, @@ -82,36 +81,6 @@ func (a *adapter) expandInvoiceLineItems(query *db.BillingInvoiceQuery, expand b }) } -func (a *adapter) DeleteGatheringInvoices(ctx context.Context, input billing.DeleteGatheringInvoicesInput) error { - if err := input.Validate(); err != nil { - return billing.ValidationError{ - Err: err, - } - } - - return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { - nAffected, err := tx.db.BillingInvoice.Update(). - Where(billinginvoice.IDIn(input.InvoiceIDs...)). - Where(billinginvoice.Namespace(input.Namespace)). - Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)). - ClearPeriodStart(). - ClearPeriodEnd(). - SetDeletedAt(clock.Now()). - Save(ctx) - if err != nil { - return err - } - - if nAffected != len(input.InvoiceIDs) { - return billing.ValidationError{ - Err: errors.New("invoices failed to delete"), - } - } - - return nil - }) -} - func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesInput) (billing.ListInvoicesResponse, error) { if err := input.Validate(); err != nil { return billing.ListInvoicesResponse{}, billing.ValidationError{ @@ -161,10 +130,6 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn query = query.Where(billinginvoice.CreatedAtLTE(*input.CreatedBefore)) } - if len(input.ExtendedStatuses) > 0 { - query = query.Where(billinginvoice.StatusIn(input.ExtendedStatuses...)) - } - if len(input.IDs) > 0 { query = query.Where(billinginvoice.IDIn(input.IDs...)) } @@ -173,14 +138,46 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn query = query.Where(billinginvoice.DeletedAtIsNil()) } - if len(input.Statuses) > 0 { - query = query.Where(func(s *sql.Selector) { - s.Where(sql.Or( - lo.Map(input.Statuses, func(status string, _ int) *sql.Predicate { - return sql.Like(billinginvoice.FieldStatus, status+"%") - })..., - )) - }) + if len(input.InvoiceTypes) > 0 { + includeStandard := slices.Contains(input.InvoiceTypes, billing.InvoiceTypeStandard) || len(input.InvoiceTypes) == 0 + includeGathering := slices.Contains(input.InvoiceTypes, billing.InvoiceTypeGathering) || len(input.InvoiceTypes) == 0 + + // Hack: right now we are using the statuses to filter for standard and gathering invoices + queries := []predicate.BillingInvoice{} + if includeStandard { + queries = append(queries, func(s *sql.Selector) { + predicates := []*sql.Predicate{ + sql.Not(sql.EQ(billinginvoice.FieldStatus, billing.StandardInvoiceStatusGathering)), + } + + if len(input.StandardInvoiceStatuses) > 0 { + predicates = append(predicates, sql.Or( + lo.Map(input.StandardInvoiceStatuses, func(status string, _ int) *sql.Predicate { + return sql.Like(billinginvoice.FieldStatus, status+"%") + })..., + )) + } + + if len(input.StandardInvoiceExtendedStatuses) > 0 { + predicates = append(predicates, sql.In(billinginvoice.FieldStatus, lo.Map(input.StandardInvoiceExtendedStatuses, func(status billing.StandardInvoiceStatus, _ int) any { + return status + })...)) + } + + s.Where(sql.And(predicates...)) + }) + } + + if includeGathering { + queries = append(queries, func(s *sql.Selector) { + s.Where( + sql.EQ(billinginvoice.FieldStatus, billing.StandardInvoiceStatusGathering)) + }) + } + + query = query.Where(billinginvoice.Or( + queries..., + )) } if len(input.Currencies) > 0 { @@ -248,7 +245,7 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn query = query.Order(billinginvoice.ByCreatedAt(order...)) } - response := pagination.Result[billing.StandardInvoice]{ + response := pagination.Result[billing.Invoice]{ Page: input.Page, } @@ -257,14 +254,33 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn return response, err } - result := make([]billing.StandardInvoice, 0, len(paged.Items)) + result := make([]billing.Invoice, 0, len(paged.Items)) for _, invoice := range paged.Items { - mapped, err := tx.mapStandardInvoiceFromDB(ctx, invoice, input.Expand) + if invoice.Status == billing.StandardInvoiceStatusGathering { + expand := billing.GatheringInvoiceExpands{} + if input.Expand.Lines { + expand = expand.With(billing.GatheringInvoiceExpandLines) + } + + if input.Expand.DeletedLines { + expand = expand.With(billing.GatheringInvoiceExpandDeletedLines) + } + + gatheredMapped, err := tx.mapGatheringInvoiceFromDB(invoice, expand) + if err != nil { + return response, err + } + + result = append(result, billing.NewInvoice(gatheredMapped)) + continue + } + + stdMapped, err := tx.mapStandardInvoiceFromDB(ctx, invoice, input.Expand) if err != nil { return response, err } - result = append(result, mapped) + result = append(result, billing.NewInvoice(stdMapped)) } response.TotalCount = paged.TotalCount @@ -866,3 +882,31 @@ func (a *adapter) IsAppUsed(ctx context.Context, appID app.AppID) error { return nil } + +func (a *adapter) GetInvoiceType(ctx context.Context, input billing.GetInvoiceTypeAdapterInput) (billing.InvoiceType, error) { + if err := input.Validate(); err != nil { + return "", err + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.InvoiceType, error) { + invoice, err := tx.db.BillingInvoice.Query(). + Where(billinginvoice.ID(input.ID)). + Where(billinginvoice.Namespace(input.Namespace)). + Only(ctx) + if err != nil { + if db.IsNotFound(err) { + return "", billing.NotFoundError{ + Err: fmt.Errorf("invoice not found: %w", err), + } + } + + return "", err + } + + if invoice.Status == billing.StandardInvoiceStatusGathering { + return billing.InvoiceTypeGathering, nil + } + + return billing.InvoiceTypeStandard, nil + }) +} diff --git a/openmeter/billing/derived.gen.go b/openmeter/billing/derived.gen.go index 790be53f89..975ba55404 100644 --- a/openmeter/billing/derived.gen.go +++ b/openmeter/billing/derived.gen.go @@ -6,6 +6,28 @@ import ( models "github.com/openmeterio/openmeter/pkg/models" ) +// deriveEqualGatheringLineBase returns whether this and that are equal. +func deriveEqualGatheringLineBase(this, that *GatheringLineBase) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + deriveEqual(&this.ManagedResource, &that.ManagedResource) && + this.Metadata.Equal(that.Metadata) && + this.Annotations.Equal(that.Annotations) && + this.ManagedBy == that.ManagedBy && + this.InvoiceID == that.InvoiceID && + this.Currency == that.Currency && + this.ServicePeriod.Equal(that.ServicePeriod) && + this.InvoiceAt.Equal(that.InvoiceAt) && + this.Price.Equal(&that.Price) && + this.FeatureKey == that.FeatureKey && + this.TaxConfig.Equal(that.TaxConfig) && + deriveEqual_(&this.RateCardDiscounts, &that.RateCardDiscounts) && + ((this.ChildUniqueReferenceID == nil && that.ChildUniqueReferenceID == nil) || (this.ChildUniqueReferenceID != nil && that.ChildUniqueReferenceID != nil && *(this.ChildUniqueReferenceID) == *(that.ChildUniqueReferenceID))) && + deriveEqual_1(this.Subscription, that.Subscription) && + ((this.SplitLineGroupID == nil && that.SplitLineGroupID == nil) || (this.SplitLineGroupID != nil && that.SplitLineGroupID != nil && *(this.SplitLineGroupID) == *(that.SplitLineGroupID))) && + this.UBPConfigID == that.UBPConfigID +} + // deriveEqualDetailedLineBase returns whether this and that are equal. func deriveEqualDetailedLineBase(this, that *DetailedLineBase) bool { return (this == nil && that == nil) || @@ -20,7 +42,7 @@ func deriveEqualDetailedLineBase(this, that *DetailedLineBase) bool { this.Currency == that.Currency && this.PerUnitAmount.Equal(that.PerUnitAmount) && this.Quantity.Equal(that.Quantity) && - deriveEqual_(&this.Totals, &that.Totals) && + deriveEqual_2(&this.Totals, &that.Totals) && this.TaxConfig.Equal(that.TaxConfig) && this.ExternalIDs.Equal(that.ExternalIDs) && this.FeeLineConfigID == that.FeeLineConfigID @@ -33,7 +55,7 @@ func deriveEqualLineDiscountBase(this, that *LineDiscountBase) bool { ((this.Description == nil && that.Description == nil) || (this.Description != nil && that.Description != nil && *(this.Description) == *(that.Description))) && ((this.ChildUniqueReferenceID == nil && that.ChildUniqueReferenceID == nil) || (this.ChildUniqueReferenceID != nil && that.ChildUniqueReferenceID != nil && *(this.ChildUniqueReferenceID) == *(that.ChildUniqueReferenceID))) && this.ExternalIDs.Equal(that.ExternalIDs) && - deriveEqual_1(&this.Reason, &that.Reason) + deriveEqual_3(&this.Reason, &that.Reason) } // deriveEqualAmountLineDiscount returns whether this and that are equal. @@ -86,10 +108,10 @@ func deriveEqualLineBase(this, that *StandardLineBase) bool { ((this.SplitLineGroupID == nil && that.SplitLineGroupID == nil) || (this.SplitLineGroupID != nil && that.SplitLineGroupID != nil && *(this.SplitLineGroupID) == *(that.SplitLineGroupID))) && ((this.ChildUniqueReferenceID == nil && that.ChildUniqueReferenceID == nil) || (this.ChildUniqueReferenceID != nil && that.ChildUniqueReferenceID != nil && *(this.ChildUniqueReferenceID) == *(that.ChildUniqueReferenceID))) && this.TaxConfig.Equal(that.TaxConfig) && - deriveEqual_2(&this.RateCardDiscounts, &that.RateCardDiscounts) && + deriveEqual_(&this.RateCardDiscounts, &that.RateCardDiscounts) && this.ExternalIDs.Equal(that.ExternalIDs) && - deriveEqual_3(this.Subscription, that.Subscription) && - deriveEqual_(&this.Totals, &that.Totals) + deriveEqual_1(this.Subscription, that.Subscription) && + deriveEqual_2(&this.Totals, &that.Totals) } // deriveEqualUsageBasedLine returns whether this and that are equal. @@ -117,41 +139,41 @@ func deriveEqual(this, that *models.ManagedResource) bool { } // deriveEqual_ returns whether this and that are equal. -func deriveEqual_(this, that *Totals) bool { +func deriveEqual_(this, that *Discounts) bool { return (this == nil && that == nil) || this != nil && that != nil && - this.Amount.Equal(that.Amount) && - this.ChargesTotal.Equal(that.ChargesTotal) && - this.DiscountsTotal.Equal(that.DiscountsTotal) && - this.TaxesInclusiveTotal.Equal(that.TaxesInclusiveTotal) && - this.TaxesExclusiveTotal.Equal(that.TaxesExclusiveTotal) && - this.TaxesTotal.Equal(that.TaxesTotal) && - this.Total.Equal(that.Total) + ((this.Percentage == nil && that.Percentage == nil) || (this.Percentage != nil && that.Percentage != nil && (*(this.Percentage)).Equal(*(that.Percentage)))) && + ((this.Usage == nil && that.Usage == nil) || (this.Usage != nil && that.Usage != nil && (*(this.Usage)).Equal(*(that.Usage)))) } // deriveEqual_1 returns whether this and that are equal. -func deriveEqual_1(this, that *DiscountReason) bool { +func deriveEqual_1(this, that *SubscriptionReference) bool { return (this == nil && that == nil) || this != nil && that != nil && - this.t == that.t && - ((this.percentage == nil && that.percentage == nil) || (this.percentage != nil && that.percentage != nil && (*(this.percentage)).Equal(*(that.percentage)))) && - ((this.usage == nil && that.usage == nil) || (this.usage != nil && that.usage != nil && (*(this.usage)).Equal(*(that.usage)))) + this.SubscriptionID == that.SubscriptionID && + this.PhaseID == that.PhaseID && + this.ItemID == that.ItemID && + this.BillingPeriod.Equal(that.BillingPeriod) } // deriveEqual_2 returns whether this and that are equal. -func deriveEqual_2(this, that *Discounts) bool { +func deriveEqual_2(this, that *Totals) bool { return (this == nil && that == nil) || this != nil && that != nil && - ((this.Percentage == nil && that.Percentage == nil) || (this.Percentage != nil && that.Percentage != nil && (*(this.Percentage)).Equal(*(that.Percentage)))) && - ((this.Usage == nil && that.Usage == nil) || (this.Usage != nil && that.Usage != nil && (*(this.Usage)).Equal(*(that.Usage)))) + this.Amount.Equal(that.Amount) && + this.ChargesTotal.Equal(that.ChargesTotal) && + this.DiscountsTotal.Equal(that.DiscountsTotal) && + this.TaxesInclusiveTotal.Equal(that.TaxesInclusiveTotal) && + this.TaxesExclusiveTotal.Equal(that.TaxesExclusiveTotal) && + this.TaxesTotal.Equal(that.TaxesTotal) && + this.Total.Equal(that.Total) } // deriveEqual_3 returns whether this and that are equal. -func deriveEqual_3(this, that *SubscriptionReference) bool { +func deriveEqual_3(this, that *DiscountReason) bool { return (this == nil && that == nil) || this != nil && that != nil && - this.SubscriptionID == that.SubscriptionID && - this.PhaseID == that.PhaseID && - this.ItemID == that.ItemID && - this.BillingPeriod.Equal(that.BillingPeriod) + this.t == that.t && + ((this.percentage == nil && that.percentage == nil) || (this.percentage != nil && that.percentage != nil && (*(this.percentage)).Equal(*(that.percentage)))) && + ((this.usage == nil && that.usage == nil) || (this.usage != nil && that.usage != nil && (*(this.usage)).Equal(*(that.usage)))) } diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index c0cc5a1040..663acb3a48 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -159,6 +159,28 @@ func (g GatheringInvoice) Clone() (GatheringInvoice, error) { return clone, nil } +func (g GatheringInvoice) GetGenericLines() mo.Option[[]GenericInvoiceLine] { + if !g.Lines.IsPresent() { + return mo.None[[]GenericInvoiceLine]() + } + + return mo.Some(lo.Map(g.Lines.OrEmpty(), func(l GatheringLine, _ int) GenericInvoiceLine { + return &gatheringInvoiceLineGenericWrapper{GatheringLine: l} + })) +} + +func (g *GatheringInvoice) SetLines(lines []GenericInvoiceLine) error { + mappedLines, err := slicesx.MapWithErr(lines, func(l GenericInvoiceLine) (GatheringLine, error) { + return l.AsInvoiceLine().AsGatheringLine() + }) + if err != nil { + return fmt.Errorf("mapping lines: %w", err) + } + + g.Lines = NewGatheringInvoiceLines(mappedLines) + return nil +} + type GatheringInvoiceExpand string func (e GatheringInvoiceExpand) Validate() error { @@ -487,6 +509,10 @@ func (i GatheringLineBase) GetChildUniqueReferenceID() *string { return i.ChildUniqueReferenceID } +func (i *GatheringLineBase) SetChildUniqueReferenceID(id *string) { + i.ChildUniqueReferenceID = id +} + func (i GatheringLineBase) GetSplitLineGroupID() *string { return i.SplitLineGroupID } @@ -522,6 +548,18 @@ func (g GatheringLineBase) GetRateCardDiscounts() Discounts { return g.RateCardDiscounts } +func (g GatheringLineBase) Equal(other GatheringLineBase) bool { + return deriveEqualGatheringLineBase(&g, &other) +} + +func (g GatheringLineBase) GetSubscriptionReference() *SubscriptionReference { + if g.Subscription == nil { + return nil + } + + return g.Subscription.Clone() +} + var ( _ GenericInvoiceLine = (*gatheringInvoiceLineGenericWrapper)(nil) _ InvoiceAtAccessor = (*gatheringInvoiceLineGenericWrapper)(nil) @@ -615,6 +653,14 @@ func (g GatheringLine) AsInvoiceLine() InvoiceLine { } } +func (g GatheringLine) Equal(other GatheringLine) bool { + return g.GatheringLineBase.Equal(other.GatheringLineBase) +} + +func (g GatheringLine) RemoveMetaForCompare() (GatheringLine, error) { + return g.WithoutDBState() +} + type CreatePendingInvoiceLinesInput struct { Customer customer.CustomerID `json:"customer"` Currency currencyx.Code `json:"currency"` @@ -706,13 +752,15 @@ type UpdateGatheringInvoiceAdapterInput = GatheringInvoice type ListGatheringInvoicesInput struct { pagination.Page - Namespaces []string - Customers []string - Currencies []currencyx.Code - OrderBy api.InvoiceOrderBy - Order sortx.Order - IncludeDeleted bool - Expand GatheringInvoiceExpands + Namespaces []string + IDs []string + NextCollectionAtBeforeOrEqual *time.Time + Customers []string + Currencies []currencyx.Code + OrderBy api.InvoiceOrderBy + Order sortx.Order + IncludeDeleted bool + Expand GatheringInvoiceExpands } func (i ListGatheringInvoicesInput) Validate() error { @@ -724,10 +772,6 @@ func (i ListGatheringInvoicesInput) Validate() error { } } - if len(i.Namespaces) == 0 { - errs = append(errs, errors.New("namespaces is required")) - } - for _, expand := range i.Expand { if err := expand.Validate(); err != nil { errs = append(errs, fmt.Errorf("expand: %w", err)) diff --git a/openmeter/billing/httpdriver/invoice.go b/openmeter/billing/httpdriver/invoice.go index ebc400ef7f..be8eddab74 100644 --- a/openmeter/billing/httpdriver/invoice.go +++ b/openmeter/billing/httpdriver/invoice.go @@ -92,7 +92,7 @@ func (h *handler) ListInvoices() ListInvoicesHandler { } for _, invoice := range invoices.Items { - invoice, err := MapInvoiceToAPI(invoice) + invoice, err := MapStandardInvoiceToAPI(invoice) if err != nil { return ListInvoicesResponse{}, err } @@ -156,7 +156,7 @@ func (h *handler) InvoicePendingLinesAction() InvoicePendingLinesActionHandler { out := make([]api.Invoice, 0, len(invoices)) for _, invoice := range invoices { - invoice, err := MapInvoiceToAPI(invoice) + invoice, err := MapStandardInvoiceToAPI(invoice) if err != nil { return nil, err } @@ -212,7 +212,7 @@ func (h *handler) GetInvoice() GetInvoiceHandler { return GetInvoiceResponse{}, err } - return MapInvoiceToAPI(invoice) + return MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[GetInvoiceResponse](http.StatusOK), httptransport.AppendOptions( @@ -298,7 +298,7 @@ func (h *handler) ProgressInvoice(action ProgressAction) ProgressInvoiceHandler return ProgressInvoiceResponse{}, err } - return MapInvoiceToAPI(invoice) + return MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[ProgressInvoiceResponse](http.StatusOK), httptransport.AppendOptions( @@ -409,7 +409,7 @@ func (h *handler) SimulateInvoice() SimulateInvoiceHandler { return SimulateInvoiceResponse{}, err } - return MapInvoiceToAPI(invoice) + return MapStandardInvoiceToAPI(invoice) }, commonhttp.JSONResponseEncoderWithStatus[SimulateInvoiceResponse](http.StatusOK), httptransport.AppendOptions( @@ -457,33 +457,52 @@ func (h *handler) UpdateInvoice() UpdateInvoiceHandler { func(ctx context.Context, request UpdateInvoiceRequest) (UpdateInvoiceResponse, error) { invoice, err := h.service.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ Invoice: request.InvoiceID, - EditFn: func(invoice *billing.StandardInvoice) error { + EditFn: func(invoice billing.Invoice) (billing.Invoice, error) { var err error - invoice.Supplier = mergeInvoiceSupplierFromAPI(invoice.Supplier, request.Input.Supplier) - invoice.Customer = mergeInvoiceCustomerFromAPI(invoice.Customer, request.Input.Customer) - invoice.Workflow, err = mergeInvoiceWorkflowFromAPI(invoice.Workflow, request.Input.Workflow) + if invoice.Type() == billing.InvoiceTypeGathering { + gatheringInvoice, err := invoice.AsGatheringInvoice() + if err != nil { + return billing.Invoice{}, fmt.Errorf("converting invoice to gathering invoice: %w", err) + } + + gatheringInvoice.Lines, err = h.mergeGatheringInvoiceLinesFromAPI(ctx, &gatheringInvoice, request.Input.Lines) + if err != nil { + return billing.Invoice{}, fmt.Errorf("merging lines: %w", err) + } + + return billing.NewInvoice(gatheringInvoice), nil + } + + stdInvoice, err := invoice.AsStandardInvoice() if err != nil { - return err + return billing.Invoice{}, fmt.Errorf("converting invoice to standard invoice: %w", err) } - invoice.Lines, err = h.mergeInvoiceLinesFromAPI(ctx, invoice, request.Input.Lines) + stdInvoice.Supplier = mergeInvoiceSupplierFromAPI(stdInvoice.Supplier, request.Input.Supplier) + stdInvoice.Customer = mergeInvoiceCustomerFromAPI(stdInvoice.Customer, request.Input.Customer) + stdInvoice.Workflow, err = mergeInvoiceWorkflowFromAPI(stdInvoice.Workflow, request.Input.Workflow) if err != nil { - return err + return billing.Invoice{}, fmt.Errorf("merging workflow: %w", err) + } + + stdInvoice.Lines, err = h.mergeStandardInvoiceLinesFromAPI(ctx, &stdInvoice, request.Input.Lines) + if err != nil { + return billing.Invoice{}, fmt.Errorf("merging lines: %w", err) } // basic fields - invoice.Description = request.Input.Description - invoice.Metadata = lo.FromPtrOr(request.Input.Metadata, map[string]string{}) + stdInvoice.Description = request.Input.Description + stdInvoice.Metadata = lo.FromPtrOr(request.Input.Metadata, map[string]string{}) - return nil + return billing.NewInvoice(stdInvoice), nil }, }) if err != nil { return UpdateInvoiceResponse{}, err } - return MapInvoiceToAPI(invoice) + return h.MapInvoiceToAPI(ctx, invoice) }, commonhttp.JSONResponseEncoderWithStatus[UpdateInvoiceResponse](http.StatusOK), httptransport.AppendOptions( @@ -494,7 +513,40 @@ func (h *handler) UpdateInvoice() UpdateInvoiceHandler { ) } -func MapInvoiceToAPI(invoice billing.StandardInvoice) (api.Invoice, error) { +func (h *handler) MapInvoiceToAPI(ctx context.Context, invoice billing.Invoice) (api.Invoice, error) { + switch invoice.Type() { + case billing.InvoiceTypeStandard: + standardInvoice, err := invoice.AsStandardInvoice() + if err != nil { + return api.Invoice{}, fmt.Errorf("converting invoice to standard invoice: %w", err) + } + + return MapStandardInvoiceToAPI(standardInvoice) + case billing.InvoiceTypeGathering: + gatheringInvoice, err := invoice.AsGatheringInvoice() + if err != nil { + return api.Invoice{}, fmt.Errorf("converting invoice to gathering invoice: %w", err) + } + + // TODO: For the V3 api let's make sure that we don't return gathering invoice customer data (or even gathering invoices) + mergedProfile, err := h.service.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{ + Customer: gatheringInvoice.GetCustomerID(), + Expand: billing.CustomerOverrideExpand{ + Customer: true, + Apps: true, + }, + }) + if err != nil { + return UpdateInvoiceResponse{}, fmt.Errorf("failed to get customer override: %w", err) + } + + return MapGatheringInvoiceToAPI(gatheringInvoice, mergedProfile.Customer, mergedProfile.MergedProfile) + default: + return api.Invoice{}, fmt.Errorf("invalid invoice type: %s", invoice.Type()) + } +} + +func MapStandardInvoiceToAPI(invoice billing.StandardInvoice) (api.Invoice, error) { var apps *api.BillingProfileAppsOrReference var err error @@ -611,7 +663,7 @@ func MapEventInvoiceToAPI(event billing.EventStandardInvoice) (api.Invoice, erro // Prefer the apps from the event event.Invoice.Workflow.Apps = nil - invoice, err := MapInvoiceToAPI(event.Invoice) + invoice, err := MapStandardInvoiceToAPI(event.Invoice) if err != nil { return api.Invoice{}, err } diff --git a/openmeter/billing/httpdriver/invoice_test.go b/openmeter/billing/httpdriver/invoice_test.go index e4d4209f8c..67d9e68cc2 100644 --- a/openmeter/billing/httpdriver/invoice_test.go +++ b/openmeter/billing/httpdriver/invoice_test.go @@ -78,7 +78,7 @@ func (s *InvoicingTestSuite) TestGatheringInvoiceSerialization() { s.NoError(err) // Let's serialize the invoice - apiInvoice, err := MapInvoiceToAPI(invoice) + apiInvoice, err := MapStandardInvoiceToAPI(invoice) s.NoError(err) // Let's deserialize the invoice diff --git a/openmeter/billing/httpdriver/invoiceline.go b/openmeter/billing/httpdriver/invoiceline.go index b61b9d1923..8c6515c514 100644 --- a/openmeter/billing/httpdriver/invoiceline.go +++ b/openmeter/billing/httpdriver/invoiceline.go @@ -634,7 +634,7 @@ func mapSimulationLineToEntity(line api.InvoiceSimulationLine) (*billing.Standar }, nil } -func lineFromInvoiceLineReplaceUpdate(line api.InvoiceLineReplaceUpdate, invoice *billing.StandardInvoice) (*billing.StandardLine, error) { +func standardLineFromInvoiceLineReplaceUpdate(line api.InvoiceLineReplaceUpdate, invoice *billing.StandardInvoice) (*billing.StandardLine, error) { rateCardParsed, err := mapAndValidateInvoiceLineRateCardDeprecatedFields(invoiceLineRateCardItems{ RateCard: line.RateCard, Price: line.Price, @@ -677,7 +677,52 @@ func lineFromInvoiceLineReplaceUpdate(line api.InvoiceLineReplaceUpdate, invoice }, nil } -func mergeLineFromInvoiceLineReplaceUpdate(existing *billing.StandardLine, line api.InvoiceLineReplaceUpdate) (*billing.StandardLine, bool, error) { +func gatheringLineFromInvoiceLineReplaceUpdate(line api.InvoiceLineReplaceUpdate, invoice *billing.GatheringInvoice) (billing.GatheringLine, error) { + rateCardParsed, err := mapAndValidateInvoiceLineRateCardDeprecatedFields(invoiceLineRateCardItems{ + RateCard: line.RateCard, + Price: line.Price, + TaxConfig: line.TaxConfig, + FeatureKey: line.FeatureKey, + }) + if err != nil { + return billing.GatheringLine{}, fmt.Errorf("failed to map usage based line: %w", err) + } + + if rateCardParsed.Price == nil { + return billing.GatheringLine{}, billing.ValidationError{ + Err: fmt.Errorf("price is required for usage based lines"), + } + } + + return billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: invoice.Namespace, + Name: line.Name, + Description: line.Description, + }), + + Metadata: lo.FromPtrOr(line.Metadata, map[string]string{}), + ManagedBy: billing.ManuallyManagedLine, + + InvoiceID: invoice.ID, + Currency: invoice.Currency, + + ServicePeriod: timeutil.ClosedPeriod{ + From: line.Period.From.Truncate(streaming.MinimumWindowSizeDuration), + To: line.Period.To.Truncate(streaming.MinimumWindowSizeDuration), + }, + InvoiceAt: line.InvoiceAt.Truncate(streaming.MinimumWindowSizeDuration), + + TaxConfig: rateCardParsed.TaxConfig, + RateCardDiscounts: rateCardParsed.Discounts, + Price: lo.FromPtr(rateCardParsed.Price), + FeatureKey: rateCardParsed.FeatureKey, + }, + }, nil +} + +func mergeStandardLineFromInvoiceLineReplaceUpdate(existing *billing.StandardLine, line api.InvoiceLineReplaceUpdate) (*billing.StandardLine, bool, error) { oldBase := existing.StandardLineBase.Clone() oldUBP := existing.UsageBased.Clone() @@ -742,7 +787,79 @@ func mergeLineFromInvoiceLineReplaceUpdate(existing *billing.StandardLine, line return existing, wasChange, nil } -func (h *handler) mergeInvoiceLinesFromAPI(ctx context.Context, invoice *billing.StandardInvoice, updatedLines []api.InvoiceLineReplaceUpdate) (billing.StandardInvoiceLines, error) { +func mergeGatheringLineFromInvoiceLineReplaceUpdate(existing billing.GatheringLine, line api.InvoiceLineReplaceUpdate) (billing.GatheringLine, error) { + old, err := existing.Clone() + if err != nil { + return billing.GatheringLine{}, fmt.Errorf("cloning existing line: %w", err) + } + + rateCardParsed, err := mapAndValidateInvoiceLineRateCardDeprecatedFields(invoiceLineRateCardItems{ + RateCard: line.RateCard, + Price: line.Price, + TaxConfig: line.TaxConfig, + FeatureKey: line.FeatureKey, + }) + if err != nil { + return billing.GatheringLine{}, billing.ValidationError{ + Err: fmt.Errorf("failed to map usage based line: %w", err), + } + } + + if line.Price == nil { + return billing.GatheringLine{}, billing.ValidationError{ + Err: fmt.Errorf("price is required for usage based lines"), + } + } + + existing.Metadata = lo.FromPtrOr(line.Metadata, api.Metadata(existing.Metadata)) + existing.Name = line.Name + existing.Description = line.Description + + existing.ServicePeriod.From = line.Period.From.Truncate(streaming.MinimumWindowSizeDuration) + existing.ServicePeriod.To = line.Period.To.Truncate(streaming.MinimumWindowSizeDuration) + existing.InvoiceAt = line.InvoiceAt.Truncate(streaming.MinimumWindowSizeDuration) + + existing.TaxConfig = rateCardParsed.TaxConfig + existing.Price = lo.FromPtr(rateCardParsed.Price) + existing.FeatureKey = rateCardParsed.FeatureKey + + // Rate card discounts are not allowed to be updated on a progressively billed line (e.g. if there is + // already a partial invoice created), as we might go short on the discount quantity. + // + // If this is ever requested: + // - we should introduce the concept of a "discount pool" that is shared across invoices and + // - editing the discount edits the pool + // - editing requires that the discount pool's quantity cannot be less than the already used + // quantity. + + if existing.SplitLineGroupID != nil && rateCardParsed.Discounts.Usage != nil && existing.RateCardDiscounts.Usage != nil { + if !equal.PtrEqual(rateCardParsed.Discounts.Usage, existing.RateCardDiscounts.Usage) { + return billing.GatheringLine{}, billing.ValidationError{ + Err: fmt.Errorf("line[%s]: %w", existing.ID, billing.ErrInvoiceLineProgressiveBillingUsageDiscountUpdateForbidden), + } + } + } + + existing.RateCardDiscounts = rateCardParsed.Discounts + + if !old.Equal(existing) { + existing.ManagedBy = billing.ManuallyManagedLine + } + + // We are not allowing period change for split lines (or their children), as that would mess up the + // calculation logic and/or we would need to update multiple invoices to correct all the references. + // + // Deletion is allowed. + if old.SplitLineGroupID != nil && !old.ServicePeriod.Equal(existing.ServicePeriod) { + return billing.GatheringLine{}, billing.ValidationError{ + Err: fmt.Errorf("line[%s]: %w", existing.ID, billing.ErrInvoiceLineNoPeriodChangeForSplitLine), + } + } + + return existing, nil +} + +func (h *handler) mergeStandardInvoiceLinesFromAPI(ctx context.Context, invoice *billing.StandardInvoice, updatedLines []api.InvoiceLineReplaceUpdate) (billing.StandardInvoiceLines, error) { linesByID, _ := slicesx.UniqueGroupBy(invoice.Lines.OrEmpty(), func(line *billing.StandardLine) string { return line.ID }) @@ -759,7 +876,7 @@ func (h *handler) mergeInvoiceLinesFromAPI(ctx context.Context, invoice *billing if id == "" || !existingLineFound { // We allow injecting fake IDs for new lines, so that discounts can reference those, // but we are not persisting them to the database - newLine, err := lineFromInvoiceLineReplaceUpdate(line, invoice) + newLine, err := standardLineFromInvoiceLineReplaceUpdate(line, invoice) if err != nil { return billing.StandardInvoiceLines{}, fmt.Errorf("failed to create new line: %w", err) } @@ -779,7 +896,7 @@ func (h *handler) mergeInvoiceLinesFromAPI(ctx context.Context, invoice *billing } foundLines.Add(id) - mergedLine, changed, err := mergeLineFromInvoiceLineReplaceUpdate(existingLine, line) + mergedLine, changed, err := mergeStandardLineFromInvoiceLineReplaceUpdate(existingLine, line) if err != nil { return billing.StandardInvoiceLines{}, fmt.Errorf("failed to merge line: %w", err) } @@ -808,3 +925,50 @@ func (h *handler) mergeInvoiceLinesFromAPI(ctx context.Context, invoice *billing return billing.NewStandardInvoiceLines(out), nil } + +func (h *handler) mergeGatheringInvoiceLinesFromAPI(ctx context.Context, invoice *billing.GatheringInvoice, updatedLines []api.InvoiceLineReplaceUpdate) (billing.GatheringInvoiceLines, error) { + linesByID, _ := slicesx.UniqueGroupBy(invoice.Lines.OrEmpty(), func(line billing.GatheringLine) string { + return line.ID + }) + + foundLines := set.New[string]() + + out := make([]billing.GatheringLine, 0, len(updatedLines)) + + for _, line := range updatedLines { + id := lo.FromPtr(line.Id) + + existingLine, existingLineFound := linesByID[id] + + if id == "" || !existingLineFound { + // We allow injecting fake IDs for new lines, so that discounts can reference those, + // but we are not persisting them to the database + newLine, err := gatheringLineFromInvoiceLineReplaceUpdate(line, invoice) + if err != nil { + return billing.GatheringInvoiceLines{}, fmt.Errorf("failed to create new line: %w", err) + } + + out = append(out, newLine) + continue + } + + foundLines.Add(id) + mergedLine, err := mergeGatheringLineFromInvoiceLineReplaceUpdate(existingLine, line) + if err != nil { + return billing.GatheringInvoiceLines{}, fmt.Errorf("failed to merge line: %w", err) + } + + out = append(out, mergedLine) + } + + lineIDs := set.New(lo.Keys(linesByID)...) + + deletedLines := set.Subtract(lineIDs, foundLines).AsSlice() + for _, id := range deletedLines { + existingLine := linesByID[id] + existingLine.DeletedAt = lo.ToPtr(clock.Now()) + out = append(out, existingLine) + } + + return billing.NewGatheringInvoiceLines(out), nil +} diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index d4763db7fc..2efc4774fc 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -50,12 +50,18 @@ func (i InvoiceID) Validate() error { type GenericInvoice interface { GenericInvoiceReader + + SetLines(lines []GenericInvoiceLine) error } type GenericInvoiceReader interface { GetDeletedAt() *time.Time GetID() string GetInvoiceID() InvoiceID + GetCustomerID() customer.CustomerID + + // GetGenericLines returns the lines of the invoice as generic lines. + GetGenericLines() mo.Option[[]GenericInvoiceLine] AsInvoice() Invoice } @@ -159,6 +165,35 @@ func (i Invoice) AsGatheringInvoice() (GatheringInvoice, error) { return *i.gatheringInvoice, nil } +func (i Invoice) AsGenericInvoice() (GenericInvoice, error) { + switch i.t { + case InvoiceTypeStandard: + if i.standardInvoice == nil { + return nil, fmt.Errorf("standard invoice is nil") + } + + cloned, err := i.standardInvoice.Clone() + if err != nil { + return nil, err + } + + return &cloned, nil + case InvoiceTypeGathering: + if i.gatheringInvoice == nil { + return nil, fmt.Errorf("gathering invoice is nil") + } + + cloned, err := i.gatheringInvoice.Clone() + if err != nil { + return nil, err + } + + return &cloned, nil + default: + return nil, fmt.Errorf("invalid invoice type: %s", i.t) + } +} + func (i Invoice) Validate() error { switch i.t { case InvoiceTypeStandard: @@ -213,7 +248,6 @@ func (i genericMultiInvoiceInput) Validate() error { } type ( - DeleteGatheringInvoicesInput = genericMultiInvoiceInput LockInvoicesForUpdateInput = genericMultiInvoiceInput AssociatedLineCountsAdapterInput = genericMultiInvoiceInput ) @@ -283,14 +317,16 @@ type ListInvoicesInput struct { Namespaces []string IDs []string Customers []string - // Statuses searches by short InvoiceStatus (e.g. draft, issued) - Statuses []string + + // StandardInvoiceStatuses searches by short StandardInvoiceStatus (e.g. draft, issued) + StandardInvoiceStatuses []string + // StandardInvoiceExtendedStatuses searches by exact StandardInvoiceStatus + StandardInvoiceExtendedStatuses []StandardInvoiceStatus + InvoiceTypes []InvoiceType HasAvailableAction []InvoiceAvailableActionsFilter - // ExtendedStatuses searches by exact InvoiceStatus - ExtendedStatuses []StandardInvoiceStatus - Currencies []currencyx.Code + Currencies []currencyx.Code IssuedAfter *time.Time IssuedBefore *time.Time @@ -345,6 +381,27 @@ func (i ListInvoicesInput) Validate() error { } } + if len(i.InvoiceTypes) > 0 { + errs := errors.Join( + lo.Map(i.InvoiceTypes, func(invoiceType InvoiceType, _ int) error { + return invoiceType.Validate() + })..., + ) + if errs != nil { + outErr = append(outErr, errs) + } + } + + willListStandardInvoices := len(i.InvoiceTypes) == 0 || slices.Contains(i.InvoiceTypes, InvoiceTypeStandard) + if !willListStandardInvoices { + if len(i.StandardInvoiceStatuses) > 0 { + outErr = append(outErr, errors.New("standard invoice statuses are not supported when listing non-standard invoices")) + } + if len(i.StandardInvoiceExtendedStatuses) > 0 { + outErr = append(outErr, errors.New("standard invoice extended statuses are not supported when listing non-standard invoices")) + } + } + if len(i.HasAvailableAction) > 0 { errs := errors.Join( lo.Map(i.HasAvailableAction, func(action InvoiceAvailableActionsFilter, _ int) error { @@ -359,7 +416,7 @@ func (i ListInvoicesInput) Validate() error { return errors.Join(outErr...) } -type ListInvoicesResponse = pagination.Result[StandardInvoice] +type ListInvoicesResponse = pagination.Result[Invoice] type InvoicePendingLinesInput struct { Customer customer.CustomerID @@ -389,3 +446,26 @@ func (i InvoicePendingLinesInput) Validate() error { return nil } + +type UpdateInvoiceInput struct { + Invoice InvoiceID + EditFn func(Invoice) (Invoice, error) + // IncludeDeletedLines signals the update to populate the deleted lines into the lines field, for the edit function + IncludeDeletedLines bool +} + +func (i UpdateInvoiceInput) Validate() error { + var outErr []error + + if err := i.Invoice.Validate(); err != nil { + outErr = append(outErr, fmt.Errorf("id: %w", err)) + } + + if i.EditFn == nil { + outErr = append(outErr, errors.New("edit function is required")) + } + + return errors.Join(outErr...) +} + +type GetInvoiceTypeAdapterInput = InvoiceID diff --git a/openmeter/billing/invoiceline.go b/openmeter/billing/invoiceline.go index 97e30cd51f..c2f3fc8a66 100644 --- a/openmeter/billing/invoiceline.go +++ b/openmeter/billing/invoiceline.go @@ -6,6 +6,8 @@ import ( "slices" "time" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/pkg/models" timeutil "github.com/openmeterio/openmeter/pkg/timeutil" @@ -121,6 +123,7 @@ type GenericInvoiceLine interface { SetDeletedAt(at *time.Time) SetPrice(price productcatalog.Price) UpdateServicePeriod(func(p *timeutil.ClosedPeriod)) + SetChildUniqueReferenceID(id *string) } // GenericInvoiceLineReader is an interface that provides access to the generic invoice fields. @@ -139,6 +142,8 @@ type GenericInvoiceLineReader interface { Validate() error AsInvoiceLine() InvoiceLine GetRateCardDiscounts() Discounts + GetSubscriptionReference() *SubscriptionReference + GetSplitLineGroupID() *string } type InvoiceAtAccessor interface { @@ -146,6 +151,10 @@ type InvoiceAtAccessor interface { SetInvoiceAt(at time.Time) } +type QuantityAccessor interface { + GetQuantity() *alpacadecimal.Decimal +} + type InvoiceLineType string const ( diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index bbd6f036db..a96d1b3d76 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -15,6 +15,7 @@ type Service interface { SplitLineGroupService InvoiceService GatheringInvoiceService + StandardInvoiceService SequenceService LockableService @@ -60,17 +61,6 @@ type InvoiceService interface { ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) GetInvoiceByID(ctx context.Context, input GetInvoiceByIdInput) (StandardInvoice, error) InvoicePendingLines(ctx context.Context, input InvoicePendingLinesInput) ([]StandardInvoice, error) - // AdvanceInvoice advances the invoice to the next stage, the advancement is stopped until: - // - an error is occurred - // - the invoice is in a state that cannot be advanced (e.g. waiting for draft period to expire) - // - the invoice is advanced to the final state - AdvanceInvoice(ctx context.Context, input AdvanceInvoiceInput) (StandardInvoice, error) - SnapshotQuantities(ctx context.Context, input SnapshotQuantitiesInput) (StandardInvoice, error) - ApproveInvoice(ctx context.Context, input ApproveInvoiceInput) (StandardInvoice, error) - RetryInvoice(ctx context.Context, input RetryInvoiceInput) (StandardInvoice, error) - DeleteInvoice(ctx context.Context, input DeleteInvoiceInput) (StandardInvoice, error) - // UpdateInvoice updates an invoice as a whole - UpdateInvoice(ctx context.Context, input UpdateInvoiceInput) (StandardInvoice, error) // SimulateInvoice generates an invoice based on the provided input, but does not persist it // can be used to execute the invoice generation logic without actually creating an invoice in the database @@ -84,11 +74,33 @@ type InvoiceService interface { RecalculateGatheringInvoices(ctx context.Context, input RecalculateGatheringInvoicesInput) error } +type StandardInvoiceService interface { + // UpdateStandardInvoice updates a standard invoice as a whole + UpdateStandardInvoice(ctx context.Context, input UpdateStandardInvoiceInput) (StandardInvoice, error) + + // ListStandardInvoices lists standard invoices for a given customer + ListStandardInvoices(ctx context.Context, input ListStandardInvoicesInput) (ListStandardInvoicesResponse, error) + + // AdvanceInvoice advances the invoice to the next stage, the advancement is stopped until: + // - an error is occurred + // - the invoice is in a state that cannot be advanced (e.g. waiting for draft period to expire) + // - the invoice is advanced to the final state + AdvanceInvoice(ctx context.Context, input AdvanceInvoiceInput) (StandardInvoice, error) + SnapshotQuantities(ctx context.Context, input SnapshotQuantitiesInput) (StandardInvoice, error) + ApproveInvoice(ctx context.Context, input ApproveInvoiceInput) (StandardInvoice, error) + RetryInvoice(ctx context.Context, input RetryInvoiceInput) (StandardInvoice, error) + DeleteInvoice(ctx context.Context, input DeleteInvoiceInput) (StandardInvoice, error) + UpdateInvoice(ctx context.Context, input UpdateInvoiceInput) (Invoice, error) + + GetStandardInvoiceById(ctx context.Context, input GetStandardInvoiceByIdInput) (StandardInvoice, error) +} + type GatheringInvoiceService interface { // CreatePendingInvoiceLines creates pending invoice lines for a customer, if the lines are zero valued, the response is nil CreatePendingInvoiceLines(ctx context.Context, input CreatePendingInvoiceLinesInput) (*CreatePendingInvoiceLinesResult, error) ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error) + GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error) UpdateGatheringInvoice(ctx context.Context, input UpdateGatheringInvoiceInput) error } diff --git a/openmeter/billing/service/gatheringinvoice.go b/openmeter/billing/service/gatheringinvoice.go index 29ad9e87f9..2e0fbc3072 100644 --- a/openmeter/billing/service/gatheringinvoice.go +++ b/openmeter/billing/service/gatheringinvoice.go @@ -26,7 +26,7 @@ func (s *Service) ListGatheringInvoices(ctx context.Context, input billing.ListG }) } -func (s *Service) emulateStandardInvoicesGatheringInvoiceFields(ctx context.Context, invoices []billing.StandardInvoice) ([]billing.StandardInvoice, error) { +func (s *Service) emulateStandardInvoicesGatheringInvoiceFields(ctx context.Context, invoices []billing.GatheringInvoice) ([]billing.StandardInvoice, error) { mergedProfiles := make(map[customer.CustomerID]billing.CustomerOverrideWithDetails) for idx := range invoices { @@ -163,10 +163,7 @@ func (s *Service) UpdateGatheringInvoice(ctx context.Context, input billing.Upda // TransactionForGatheringInvoiceManipulation if invoice.Lines.NonDeletedLineCount() == 0 { - if err := s.adapter.DeleteGatheringInvoices(ctx, billing.DeleteGatheringInvoicesInput{ - Namespace: input.Invoice.Namespace, - InvoiceIDs: []string{invoice.ID}, - }); err != nil { + if err := s.adapter.DeleteGatheringInvoice(ctx, invoice.GetInvoiceID()); err != nil { return fmt.Errorf("deleting gathering invoice: %w", err) } } @@ -205,3 +202,13 @@ func (s Service) checkIfGatheringLinesAreInvoicable(ctx context.Context, invoice })..., ) } + +func (s *Service) GetGatheringInvoiceById(ctx context.Context, input billing.GetGatheringInvoiceByIdInput) (billing.GatheringInvoice, error) { + if err := input.Validate(); err != nil { + return billing.GatheringInvoice{}, err + } + + return transaction.Run(ctx, s.adapter, func(ctx context.Context) (billing.GatheringInvoice, error) { + return s.adapter.GetGatheringInvoiceById(ctx, input) + }) +} diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 8c8c9bf8f1..7cdbda83d0 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -536,111 +536,6 @@ func (s *Service) DeleteInvoice(ctx context.Context, input billing.DeleteInvoice return s.executeTriggerOnInvoice(ctx, input, billing.TriggerDelete) } -func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoiceInput) (billing.StandardInvoice, error) { - if err := input.Validate(); err != nil { - return billing.StandardInvoice{}, billing.ValidationError{ - Err: err, - } - } - - invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ - Invoice: input.Invoice, - Expand: billing.InvoiceExpand{}, // We don't want to expand anything as we will have to refetch the invoice anyway - }) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("fetching invoice: %w", err) - } - - if invoice.Status == billing.StandardInvoiceStatusGathering { - customerProfile, err := s.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{ - Customer: invoice.CustomerID(), - }) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("fetching profile: %w", err) - } - - return transactionForInvoiceManipulation(ctx, s, invoice.CustomerID(), func(ctx context.Context) (billing.StandardInvoice, error) { - invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ - Invoice: input.Invoice, - Expand: billing.InvoiceExpandAll. - SetDeletedLines(input.IncludeDeletedLines), - }) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("fetching invoice: %w", err) - } - - if err := input.EditFn(&invoice); err != nil { - return billing.StandardInvoice{}, fmt.Errorf("editing invoice: %w", err) - } - - invoice.Lines, err = invoice.Lines.WithNormalizedValues() - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("normalizing lines: %w", err) - } - - if err := s.invoiceCalculator.CalculateLegacyGatheringInvoice(&invoice); err != nil { - return billing.StandardInvoice{}, fmt.Errorf("calculating invoice[%s]: %w", invoice.ID, err) - } - - if err := invoice.Validate(); err != nil { - return billing.StandardInvoice{}, billing.ValidationError{ - Err: err, - } - } - - featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Namespace, invoice.Lines) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) - } - - // Check if the new lines are still invoicable - if err := s.checkIfLinesAreInvoicable(ctx, &invoice, customerProfile.MergedProfile.WorkflowConfig.Invoicing.ProgressiveBilling, featureMeters); err != nil { - return billing.StandardInvoice{}, err - } - - invoice, err = s.updateInvoice(ctx, invoice) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("updating invoice[%s]: %w", input.Invoice.ID, err) - } - - // Auto delete the invoice if it has no lines, this needs to happen here, as we are in a - // TranscationForGatheringInvoiceManipulation - - if invoice.Lines.NonDeletedLineCount() == 0 { - if err := s.adapter.DeleteGatheringInvoices(ctx, billing.DeleteGatheringInvoicesInput{ - Namespace: input.Invoice.Namespace, - InvoiceIDs: []string{invoice.ID}, - }); err != nil { - return billing.StandardInvoice{}, fmt.Errorf("deleting gathering invoice: %w", err) - } - } - - return invoice, nil - }) - } - - return s.executeTriggerOnInvoice( - ctx, - input.Invoice, - billing.TriggerUpdated, - ExecuteTriggerWithIncludeDeletedLines(input.IncludeDeletedLines), - ExecuteTriggerWithAllowInStates(billing.StandardInvoiceStatusDraftUpdating), - ExecuteTriggerWithEditCallback(func(sm *InvoiceStateMachine) error { - if err := input.EditFn(&sm.Invoice); err != nil { - return fmt.Errorf("editing invoice: %w", err) - } - - if err := sm.Invoice.Validate(); err != nil { - return billing.ValidationError{ - Err: err, - } - } - - return nil - }), - ) -} - // updateInvoice calls the adapter to update the invoice and returns the updated invoice including any expands that are // the responsibility of the service func (s Service) updateInvoice(ctx context.Context, in billing.UpdateInvoiceAdapterInput) (billing.StandardInvoice, error) { @@ -921,3 +816,95 @@ func (s *Service) RecalculateGatheringInvoices(ctx context.Context, input billin return nil }) } + +func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoiceInput) (billing.Invoice, error) { + if err := input.Validate(); err != nil { + return billing.Invoice{}, billing.ValidationError{ + Err: err, + } + } + return transaction.Run(ctx, s.adapter, func(ctx context.Context) (billing.Invoice, error) { + invoiceType, err := s.adapter.GetInvoiceType(ctx, input.Invoice) + if err != nil { + return billing.Invoice{}, fmt.Errorf("getting invoice type: %w", err) + } + + if invoiceType == billing.InvoiceTypeGathering { + err := s.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: input.Invoice, + IncludeDeletedLines: input.IncludeDeletedLines, + EditFn: func(invoice *billing.GatheringInvoice) error { + if invoice == nil { + return fmt.Errorf("invoice is nil") + } + + editedInvoice, err := input.EditFn(billing.NewInvoice(*invoice)) + if err != nil { + return fmt.Errorf("editing invoice: %w", err) + } + + editedGatheringInvoice, err := editedInvoice.AsGatheringInvoice() + if err != nil { + return fmt.Errorf("converting invoice to gathering invoice: %w", err) + } + + *invoice = editedGatheringInvoice + + return nil + }, + }) + if err != nil { + return billing.Invoice{}, fmt.Errorf("updating gathering invoice: %w", err) + } + + expand := billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandAvailableActions, + } + + if input.IncludeDeletedLines { + expand = expand.With(billing.GatheringInvoiceExpandDeletedLines) + } + + gatheringInvoice, err := s.adapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: input.Invoice, + Expand: expand, + }) + if err != nil { + return billing.Invoice{}, fmt.Errorf("fetching gathering invoice: %w", err) + } + + return billing.NewInvoice(gatheringInvoice), nil + } + + // Standard invoice + standardInvoice, err := s.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: input.Invoice, + IncludeDeletedLines: input.IncludeDeletedLines, + EditFn: func(invoice *billing.StandardInvoice) error { + if invoice == nil { + return fmt.Errorf("invoice is nil") + } + + editedInvoice, err := input.EditFn(billing.NewInvoice(*invoice)) + if err != nil { + return fmt.Errorf("editing invoice: %w", err) + } + + editedStandardInvoice, err := editedInvoice.AsStandardInvoice() + if err != nil { + return fmt.Errorf("converting invoice to standard invoice: %w", err) + } + + *invoice = editedStandardInvoice + + return nil + }, + }) + if err != nil { + return billing.Invoice{}, fmt.Errorf("updating standard invoice: %w", err) + } + + return billing.NewInvoice(standardInvoice), nil + }) +} diff --git a/openmeter/billing/service/stdinvoice.go b/openmeter/billing/service/stdinvoice.go new file mode 100644 index 0000000000..ce3336241c --- /dev/null +++ b/openmeter/billing/service/stdinvoice.go @@ -0,0 +1,102 @@ +package billingservice + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing" +) + +var _ billing.StandardInvoiceService = (*Service)(nil) + +func (s *Service) UpdateStandardInvoice(ctx context.Context, input billing.UpdateStandardInvoiceInput) (billing.StandardInvoice, error) { + if err := input.Validate(); err != nil { + return billing.StandardInvoice{}, billing.ValidationError{ + Err: err, + } + } + + invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: input.Invoice, + Expand: billing.InvoiceExpand{}, // We don't want to expand anything as we will have to refetch the invoice anyway + }) + if err != nil { + return billing.StandardInvoice{}, fmt.Errorf("fetching invoice: %w", err) + } + + if invoice.Status == billing.StandardInvoiceStatusGathering { + return billing.StandardInvoice{}, billing.ValidationError{ + Err: fmt.Errorf("invoice[%s] is a gathering invoice, cannot be updated via the standard invoice service", invoice.ID), + } + } + + return s.executeTriggerOnInvoice( + ctx, + input.Invoice, + billing.TriggerUpdated, + ExecuteTriggerWithIncludeDeletedLines(input.IncludeDeletedLines), + ExecuteTriggerWithAllowInStates(billing.StandardInvoiceStatusDraftUpdating), + ExecuteTriggerWithEditCallback(func(sm *InvoiceStateMachine) error { + if err := input.EditFn(&sm.Invoice); err != nil { + return fmt.Errorf("editing invoice: %w", err) + } + + if err := sm.Invoice.Validate(); err != nil { + return billing.ValidationError{ + Err: err, + } + } + + return nil + }), + ) +} + +func (s *Service) GetStandardInvoiceById(ctx context.Context, input billing.GetStandardInvoiceByIdInput) (billing.StandardInvoice, error) { + invoice, err := s.adapter.GetInvoiceById(ctx, input) + if err != nil { + return billing.StandardInvoice{}, err + } + + return invoice, nil +} + +func (s *Service) ListStandardInvoices(ctx context.Context, input billing.ListStandardInvoicesInput) (billing.ListStandardInvoicesResponse, error) { + invoices, err := s.adapter.ListInvoices(ctx, input) + if err != nil { + return billing.ListStandardInvoicesResponse{}, err + } + + updatedInvoices, err := s.emulateStandardInvoicesGatheringInvoiceFields(ctx, invoices.Items) + if err != nil { + return billing.ListInvoicesResponse{}, fmt.Errorf("error emulating standard invoices gathering invoice fields: %w", err) + } + + invoices.Items = updatedInvoices + + for i := range invoices.Items { + invoiceID := invoices.Items[i].ID + + invoices.Items[i], err = s.resolveWorkflowApps(ctx, invoices.Items[i]) + if err != nil { + return billing.ListInvoicesResponse{}, fmt.Errorf("error resolving workflow apps [%s]: %w", invoiceID, err) + } + + invoices.Items[i], err = s.resolveStatusDetails(ctx, invoices.Items[i]) + if err != nil { + return billing.ListInvoicesResponse{}, fmt.Errorf("error resolving status details for invoice [%s]: %w", invoiceID, err) + } + + if input.Expand.RecalculateGatheringInvoice { + invoices.Items[i], err = s.recalculateGatheringInvoice(ctx, recalculateGatheringInvoiceInput{ + Invoice: invoices.Items[i], + Expand: input.Expand, + }) + if err != nil { + return billing.ListInvoicesResponse{}, fmt.Errorf("error recalculating gathering invoice [%s]: %w", invoiceID, err) + } + } + } + + return invoices, nil +} diff --git a/openmeter/billing/stdinvoice.go b/openmeter/billing/stdinvoice.go index 5b87971e96..e183e358e7 100644 --- a/openmeter/billing/stdinvoice.go +++ b/openmeter/billing/stdinvoice.go @@ -9,12 +9,15 @@ import ( "github.com/samber/lo" "github.com/samber/mo" + "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/openmeter/app" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/sortx" ) type StandardInvoiceStatusCategory string @@ -282,6 +285,13 @@ func (i StandardInvoiceBase) GetInvoiceID() InvoiceID { } } +func (i StandardInvoiceBase) GetCustomerID() customer.CustomerID { + return customer.CustomerID{ + Namespace: i.Namespace, + ID: i.Customer.CustomerID, + } +} + var _ GenericInvoice = (*StandardInvoice)(nil) type StandardInvoice struct { @@ -340,6 +350,33 @@ func (i StandardInvoice) AsInvoice() Invoice { } } +func (i StandardInvoice) GetGenericLines() mo.Option[[]GenericInvoiceLine] { + if !i.Lines.IsPresent() { + return mo.None[[]GenericInvoiceLine]() + } + + return mo.Some(lo.Map(i.Lines.OrEmpty(), func(l *StandardLine, _ int) GenericInvoiceLine { + return &standardInvoiceLineGenericWrapper{StandardLine: l} + })) +} + +func (i *StandardInvoice) SetLines(lines []GenericInvoiceLine) error { + mappedLines, err := slicesx.MapWithErr(lines, func(l GenericInvoiceLine) (*StandardLine, error) { + line, err := l.AsInvoiceLine().AsStandardLine() + if err != nil { + return nil, err + } + + return &line, nil + }) + if err != nil { + return fmt.Errorf("mapping lines: %w", err) + } + + i.Lines = NewStandardInvoiceLines(mappedLines) + return nil +} + func (i *StandardInvoice) MergeValidationIssues(errIn error, reportingComponent ComponentName) error { i.ValidationIssues = lo.Filter(i.ValidationIssues, func(issue ValidationIssue, _ int) bool { return issue.Component != reportingComponent @@ -713,14 +750,14 @@ func (i UpdateInvoiceLinesInternalInput) Validate() error { return nil } -type UpdateInvoiceInput struct { +type UpdateStandardInvoiceInput struct { Invoice InvoiceID EditFn func(*StandardInvoice) error // IncludeDeletedLines signals the update to populate the deleted lines into the lines field, for the edit function IncludeDeletedLines bool } -func (i UpdateInvoiceInput) Validate() error { +func (i UpdateStandardInvoiceInput) Validate() error { if err := i.Invoice.Validate(); err != nil { return fmt.Errorf("id: %w", err) } @@ -879,3 +916,105 @@ func (i UpdateInvoiceFieldsInput) Validate() error { } type RecalculateGatheringInvoicesInput = customer.CustomerID + +type ListStandardInvoicesInput struct { + pagination.Page + + Namespaces []string + IDs []string + // DraftUtil allows to filter invoices which have their draft state expired based on the provided time. + // Invoice is expired if the time defined by Invoice.DraftUntil is in the past compared to ListInvoicesInput.DraftUntil. + DraftUntil *time.Time + + // ExtendedStatuses searches by exact InvoiceStatus + ExtendedStatuses []StandardInvoiceStatus + HasAvailableAction []InvoiceAvailableActionsFilter + ExternalIDs *ListInvoicesExternalIDFilter + + /// DELETE everything below this line + + Customers []string + // Statuses searches by short InvoiceStatus (e.g. draft, issued) + Statuses []string + + Currencies []currencyx.Code + + IssuedAfter *time.Time + IssuedBefore *time.Time + + PeriodStartAfter *time.Time + PeriodStartBefore *time.Time + + // Filter by invoice creation time + CreatedAfter *time.Time + CreatedBefore *time.Time + + IncludeDeleted bool + + // CollectionAt allows to filter invoices which have their collection_at attribute is in the past compared + // to the time provided in CollectionAt parameter. + CollectionAt *time.Time + + Expand InvoiceExpand + + OrderBy api.InvoiceOrderBy + Order sortx.Order +} + +func (i ListStandardInvoicesInput) Validate() error { + var outErr []error + + if i.IssuedAfter != nil && i.IssuedBefore != nil && i.IssuedAfter.After(*i.IssuedBefore) { + outErr = append(outErr, errors.New("issuedAfter must be before issuedBefore")) + } + + if i.CreatedAfter != nil && i.CreatedBefore != nil && i.CreatedAfter.After(*i.CreatedBefore) { + outErr = append(outErr, errors.New("createdAfter must be before createdBefore")) + } + + if i.PeriodStartAfter != nil && i.PeriodStartBefore != nil && i.PeriodStartAfter.After(*i.PeriodStartBefore) { + outErr = append(outErr, errors.New("periodStartAfter must be before periodStartBefore")) + } + + if err := i.Expand.Validate(); err != nil { + outErr = append(outErr, fmt.Errorf("expand: %w", err)) + } + + if i.ExternalIDs != nil { + if err := i.ExternalIDs.Validate(); err != nil { + outErr = append(outErr, fmt.Errorf("external IDs: %w", err)) + } + } + + if len(i.HasAvailableAction) > 0 { + errs := errors.Join( + lo.Map(i.HasAvailableAction, func(action InvoiceAvailableActionsFilter, _ int) error { + return action.Validate() + })..., + ) + if errs != nil { + outErr = append(outErr, errs) + } + } + + return errors.Join(outErr...) +} + +type ListStandardInvoicesResponse = pagination.Result[StandardInvoice] + +type GetStandardInvoiceByIdInput struct { + Invoice InvoiceID + Expand InvoiceExpand +} + +func (i GetStandardInvoiceByIdInput) Validate() error { + if err := i.Invoice.Validate(); err != nil { + return fmt.Errorf("id: %w", err) + } + + if err := i.Expand.Validate(); err != nil { + return fmt.Errorf("expand: %w", err) + } + + return nil +} diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index 1e09d8833d..d1f5ce6eb0 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -176,7 +176,10 @@ func (i LineExternalIDs) Equal(other LineExternalIDs) bool { return i.Invoicing == other.Invoicing } -var _ GenericInvoiceLine = (*standardInvoiceLineGenericWrapper)(nil) +var ( + _ GenericInvoiceLine = (*standardInvoiceLineGenericWrapper)(nil) + _ QuantityAccessor = (*standardInvoiceLineGenericWrapper)(nil) +) // standardInvoiceLineGenericWrapper is a wrapper around a standard line that implements the GenericInvoiceLine interface. // for methods that are present for the specific line type too. @@ -255,6 +258,10 @@ func (i StandardLine) GetChildUniqueReferenceID() *string { return i.ChildUniqueReferenceID } +func (i *StandardLine) SetChildUniqueReferenceID(id *string) { + i.ChildUniqueReferenceID = id +} + func (i StandardLine) AsInvoiceLine() InvoiceLine { return InvoiceLine{ t: InvoiceLineTypeStandard, @@ -262,6 +269,14 @@ func (i StandardLine) AsInvoiceLine() InvoiceLine { } } +func (i StandardLine) GetQuantity() *alpacadecimal.Decimal { + if i.UsageBased == nil { + return nil + } + + return i.UsageBased.Quantity +} + // ToGatheringLineBase converts the standard line to a gathering line base. // This is temporary until the full gathering invoice functionality is split. func (i StandardLine) ToGatheringLineBase() (GatheringLineBase, error) { @@ -423,6 +438,14 @@ func (i StandardLine) GetInvoiceAt() time.Time { return i.InvoiceAt } +func (i StandardLine) GetSubscriptionReference() *SubscriptionReference { + if i.Subscription == nil { + return nil + } + + return i.Subscription.Clone() +} + type cloneOptions struct { skipDBState bool skipChildren bool diff --git a/openmeter/billing/validators/customer/customer.go b/openmeter/billing/validators/customer/customer.go index 754925ee91..ac55406864 100644 --- a/openmeter/billing/validators/customer/customer.go +++ b/openmeter/billing/validators/customer/customer.go @@ -68,7 +68,7 @@ func (v *Validator) ValidateDeleteCustomer(ctx context.Context, input customer.D } } - gatheringInvoices, err := v.billingService.ListInvoices(ctx, billing.ListInvoicesInput{ + invoices, err := v.billingService.ListInvoices(ctx, billing.ListInvoicesInput{ Namespaces: []string{input.Namespace}, Customers: []string{input.ID}, }) @@ -76,16 +76,30 @@ func (v *Validator) ValidateDeleteCustomer(ctx context.Context, input customer.D return err } - errs := make([]error, 0, len(gatheringInvoices.Items)) - for _, inv := range gatheringInvoices.Items { - if inv.Status == billing.StandardInvoiceStatusGathering { - errs = append(errs, fmt.Errorf("invoice %s is still in gathering state", inv.ID)) + errs := make([]error, 0, len(invoices.Items)) + for _, inv := range invoices.Items { + if inv.Type() == billing.InvoiceTypeGathering { + gatheringInvoice, err := inv.AsGatheringInvoice() + if err != nil { + return err + } + + if gatheringInvoice.DeletedAt != nil { + continue + } + errs = append(errs, fmt.Errorf("invoice %s is still in gathering state", gatheringInvoice.ID)) continue } - if !inv.Status.IsFinal() { - errs = append(errs, fmt.Errorf("invoice %s is not in final state, please either delete the invoice or mark it uncollectible", inv.ID)) + stdInvoice, err := inv.AsStandardInvoice() + if err != nil { + return err + } + + if !stdInvoice.Status.IsFinal() { + errs = append(errs, fmt.Errorf("invoice %s is not in final state, please either delete the invoice or mark it uncollectible", stdInvoice.ID)) + continue } } diff --git a/openmeter/billing/worker/advance/advance.go b/openmeter/billing/worker/advance/advance.go index 9f2c6dece3..d305832415 100644 --- a/openmeter/billing/worker/advance/advance.go +++ b/openmeter/billing/worker/advance/advance.go @@ -14,7 +14,7 @@ import ( ) type AutoAdvancer struct { - invoice billing.InvoiceService + invoice billing.StandardInvoiceService logger *slog.Logger } @@ -77,7 +77,7 @@ func (a *AutoAdvancer) All(ctx context.Context, namespaces []string, batchSize i // ListInvoicesPendingAutoAdvance lists invoices that are due to be auto-advanced func (a *AutoAdvancer) ListInvoicesPendingAutoAdvance(ctx context.Context, namespaces []string, ids []string) ([]billing.StandardInvoice, error) { - resp, err := a.invoice.ListInvoices(ctx, billing.ListInvoicesInput{ + resp, err := a.invoice.ListStandardInvoices(ctx, billing.ListStandardInvoicesInput{ ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusDraftWaitingAutoApproval}, DraftUntil: lo.ToPtr(time.Now()), Namespaces: namespaces, @@ -92,7 +92,7 @@ func (a *AutoAdvancer) ListInvoicesPendingAutoAdvance(ctx context.Context, names // ListInvoicesPendingCollection lists invoices that are due to be collected func (a *AutoAdvancer) ListInvoicesPendingCollection(ctx context.Context, namespaces []string, ids []string) ([]billing.StandardInvoice, error) { - resp, err := a.invoice.ListInvoices(ctx, billing.ListInvoicesInput{ + resp, err := a.invoice.ListStandardInvoices(ctx, billing.ListStandardInvoicesInput{ ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusDraftWaitingForCollection}, CollectionAt: lo.ToPtr(time.Now()), Namespaces: namespaces, @@ -107,7 +107,7 @@ func (a *AutoAdvancer) ListInvoicesPendingCollection(ctx context.Context, namesp // ListStuckInvoicesNeedingAdvance lists invoices that are stuck in some advanceable state (this is a fail-safe mechanism) func (a *AutoAdvancer) ListStuckInvoicesNeedingAdvance(ctx context.Context, namespaces []string, ids []string) ([]billing.StandardInvoice, error) { - resp, err := a.invoice.ListInvoices(ctx, billing.ListInvoicesInput{ + resp, err := a.invoice.ListStandardInvoices(ctx, billing.ListStandardInvoicesInput{ HasAvailableAction: []billing.InvoiceAvailableActionsFilter{billing.InvoiceAvailableActionsFilterAdvance}, Namespaces: namespaces, IDs: ids, @@ -148,7 +148,7 @@ func (a *AutoAdvancer) AdvanceInvoice(ctx context.Context, id billing.InvoiceID) // ErrInvoiceCannotAdvance is returned when the invoice cannot be advanced due to state machine settings // thus we can safely ignore this error, we will retry if errors.Is(err, billing.ErrInvoiceCannotAdvance) { - invoice, err := a.invoice.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + invoice, err := a.invoice.GetStandardInvoiceById(ctx, billing.GetStandardInvoiceByIdInput{ Invoice: id, }) if err != nil { diff --git a/openmeter/billing/worker/collect/collect.go b/openmeter/billing/worker/collect/collect.go index ffa10d1117..0c868207c4 100644 --- a/openmeter/billing/worker/collect/collect.go +++ b/openmeter/billing/worker/collect/collect.go @@ -39,17 +39,16 @@ func (i ListCollectableInvoicesInput) Validate() error { return errors.Join(errs...) } -func (a *InvoiceCollector) ListCollectableInvoices(ctx context.Context, params ListCollectableInvoicesInput) ([]billing.StandardInvoice, error) { +func (a *InvoiceCollector) ListCollectableInvoices(ctx context.Context, params ListCollectableInvoicesInput) ([]billing.GatheringInvoice, error) { if err := params.Validate(); err != nil { return nil, fmt.Errorf("invalid input: %w", err) } - resp, err := a.billing.ListInvoices(ctx, billing.ListInvoicesInput{ - Namespaces: params.Namespaces, - IDs: params.InvoiceIDs, - Customers: params.Customers, - CollectionAt: lo.ToPtr(params.CollectionAt), - ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusGathering}, + resp, err := a.billing.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ + Namespaces: params.Namespaces, + IDs: params.InvoiceIDs, + Customers: params.Customers, + NextCollectionAtBeforeOrEqual: lo.ToPtr(params.CollectionAt), }) if err != nil { return nil, fmt.Errorf("failed to list collectable invoices: %w", err) @@ -132,10 +131,10 @@ func (a *InvoiceCollector) All(ctx context.Context, namespaces []string, custome return nil } - customerIDs := lo.Map(invoices, func(i billing.StandardInvoice, _ int) customer.CustomerID { + customerIDs := lo.Map(invoices, func(i billing.GatheringInvoice, _ int) customer.CustomerID { return customer.CustomerID{ Namespace: i.Namespace, - ID: i.Customer.CustomerID, + ID: i.CustomerID, } }) diff --git a/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go b/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go index 75c912a3b4..6b183aa9d0 100644 --- a/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go +++ b/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go @@ -176,8 +176,8 @@ func (u *InvoiceUpdater) provisionUpcomingLines(ctx context.Context, customerID } func (u *InvoiceUpdater) updateMutableStandardInvoice(ctx context.Context, invoice billing.StandardInvoice, linePatches invoicePatches) error { - updatedInvoice, err := u.billingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + updatedInvoice, err := u.billingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), IncludeDeletedLines: true, EditFn: func(invoice *billing.StandardInvoice) error { // Let's delete lines if needed diff --git a/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go b/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go index bbdffdeb9c..bbbd72d6b1 100644 --- a/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go @@ -126,19 +126,19 @@ func (s *SuiteBase) AfterTest(ctx context.Context, suiteName, testName string) { s.Service.featureFlags = FeatureFlags{} } -func (s *SuiteBase) gatheringInvoice(ctx context.Context, namespace string, customerID string) billing.StandardInvoice { +func (s *SuiteBase) gatheringInvoice(ctx context.Context, namespace string, customerID string) billing.GatheringInvoice { s.T().Helper() - invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + invoices, err := s.BillingService.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ Namespaces: []string{namespace}, Customers: []string{customerID}, Page: pagination.Page{ PageSize: 10, PageNumber: 1, }, - Expand: billing.InvoiceExpandAll, - Statuses: []string{ - string(billing.StandardInvoiceStatusGathering), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandAvailableActions, }, }) @@ -150,17 +150,14 @@ func (s *SuiteBase) gatheringInvoice(ctx context.Context, namespace string, cust func (s *SuiteBase) expectNoGatheringInvoice(ctx context.Context, namespace string, customerID string) { s.T().Helper() - invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + invoices, err := s.BillingService.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ Namespaces: []string{namespace}, Customers: []string{customerID}, Page: pagination.Page{ PageSize: 10, PageNumber: 1, }, - Expand: billing.InvoiceExpandAll, - Statuses: []string{ - string(billing.StandardInvoiceStatusGathering), - }, + Expand: billing.GatheringInvoiceExpands{}, }) s.NoError(err) @@ -177,12 +174,12 @@ func (s *SuiteBase) enableProrating() { s.Service.featureFlags.EnableFlatFeeInArrearsProrating = true } -func (s *SuiteBase) getLineByChildID(invoice billing.StandardInvoice, childID string) *billing.StandardLine { +func (s *SuiteBase) getGatheringLineByChildID(invoice billing.GatheringInvoice, childID string) *billing.GatheringLine { s.T().Helper() - for _, line := range invoice.Lines.OrEmpty() { + for idx, line := range invoice.Lines.OrEmpty() { if line.ChildUniqueReferenceID != nil && *line.ChildUniqueReferenceID == childID { - return line + return &invoice.Lines.OrEmpty()[idx] } } @@ -191,11 +188,25 @@ func (s *SuiteBase) getLineByChildID(invoice billing.StandardInvoice, childID st return nil } -func (s *SuiteBase) expectNoLineWithChildID(invoice billing.StandardInvoice, childID string) { +func (s *SuiteBase) getStandardLineByChildID(invoice billing.StandardInvoice, childID string) *billing.StandardLine { s.T().Helper() for _, line := range invoice.Lines.OrEmpty() { if line.ChildUniqueReferenceID != nil && *line.ChildUniqueReferenceID == childID { + return line + } + } + + s.Failf("line not found", "line with child id %s not found", childID) + + return nil +} + +func (s *SuiteBase) expectNoLineWithChildID(invoice billing.GenericInvoiceReader, childID string) { + s.T().Helper() + + for _, line := range invoice.GetGenericLines().OrEmpty() { + if line.GetChildUniqueReferenceID() != nil && *line.GetChildUniqueReferenceID() == childID { s.Failf("line found", "line with child id %s found", childID) } } @@ -224,16 +235,19 @@ type expectedLine struct { Price mo.Option[*productcatalog.Price] Periods []billing.Period InvoiceAt mo.Option[[]time.Time] - AdditionalChecks func(line *billing.StandardLine) + AdditionalChecks func(line billing.GenericInvoiceLine) } -func (s *SuiteBase) expectLines(invoice billing.StandardInvoice, subscriptionID string, expectedLines []expectedLine) { +func (s *SuiteBase) expectLines(invoice billing.GenericInvoiceReader, subscriptionID string, expectedLines []expectedLine) { s.T().Helper() - lines := invoice.Lines.OrEmpty() + lines := invoice.GetGenericLines() + if lines.IsAbsent() { + s.Failf("lines not found", "lines not found for invoice %s", invoice.GetID()) + } - existingLineChildIDs := lo.Map(lines, func(line *billing.StandardLine, _ int) string { - return lo.FromPtrOr(line.ChildUniqueReferenceID, line.ID) + existingLineChildIDs := lo.Map(lines.OrEmpty(), func(line billing.GenericInvoiceLine, _ int) string { + return lo.FromPtrOr(line.GetChildUniqueReferenceID(), line.GetID()) }) expectedLineIds := lo.Flatten(lo.Map(expectedLines, func(expectedLine expectedLine, _ int) []string { @@ -245,31 +259,41 @@ func (s *SuiteBase) expectLines(invoice billing.StandardInvoice, subscriptionID for _, expectedLine := range expectedLines { childIDs := expectedLine.Matcher.ChildIDs(subscriptionID) for idx, childID := range childIDs { - line, found := lo.Find(lines, func(line *billing.StandardLine) bool { - return lo.FromPtrOr(line.ChildUniqueReferenceID, line.ID) == childID + line, found := lo.Find(lines.OrEmpty(), func(line billing.GenericInvoiceLine) bool { + return lo.FromPtrOr(line.GetChildUniqueReferenceID(), line.GetID()) == childID }) s.Truef(found, "line not found with child id %s", childID) s.NotNil(line) if expectedLine.Qty.IsPresent() { - if line.UsageBased == nil { - s.Failf("usage based line not found", "line not found with child id %s", childID) - } else if line.UsageBased.Quantity == nil { - s.Failf("usage based line quantity not found", "line not found with child id %s", childID) + lineQuantityAccessor, ok := line.(billing.QuantityAccessor) + if !ok { + s.Failf("line is not a quantity accessor", "line is not a quantity accessor with child id %s", childID) + } + + lineQuantity := lineQuantityAccessor.GetQuantity() + if lineQuantity == nil { + s.Failf("line quantity not found", "line quantity not found with child id %s", childID) } else { - s.Equal(expectedLine.Qty.OrEmpty(), line.UsageBased.Quantity.InexactFloat64(), "%s: quantity", childID) + s.Equal(expectedLine.Qty.OrEmpty(), lineQuantity.InexactFloat64(), "%s: quantity", childID) } } if expectedLine.Price.IsPresent() { - s.Equal(*expectedLine.Price.OrEmpty(), *line.UsageBased.Price, "%s: price", childID) + s.Equal(*expectedLine.Price.OrEmpty(), *line.GetPrice(), "%s: price", childID) } - s.Equal(expectedLine.Periods[idx].Start, line.Period.Start, "%s: period start", childID) - s.Equal(expectedLine.Periods[idx].End, line.Period.End, "%s: period end", childID) + s.Equal(expectedLine.Periods[idx].Start, line.GetServicePeriod().From, "%s: period start", childID) + s.Equal(expectedLine.Periods[idx].End, line.GetServicePeriod().To, "%s: period end", childID) if expectedLine.InvoiceAt.IsPresent() { - s.Equal(expectedLine.InvoiceAt.OrEmpty()[idx], line.InvoiceAt, "%s: invoice at", childID) + invoiceAtAccessor, ok := line.(billing.InvoiceAtAccessor) + if !ok { + s.Failf("line is not a invoice at accessor", "line is not a invoice at accessor with child id %s", childID) + } + + invoiceAt := invoiceAtAccessor.GetInvoiceAt() + s.Equal(expectedLine.InvoiceAt.OrEmpty()[idx], invoiceAt, "%s: invoice at", childID) } if expectedLine.AdditionalChecks != nil { @@ -423,12 +447,47 @@ func (s *SuiteBase) generatePeriods(startStr, endStr string, cadenceStr string, // populateChildIDsFromParents copies over the child ID from the parent line, if it's not already set // as line splitting doesn't set the child ID on child lines to prevent conflicts if multiple split lines // end up on a single invoice. -func (s *SuiteBase) populateChildIDsFromParents(invoice *billing.StandardInvoice) { - for _, line := range invoice.Lines.OrEmpty() { - if line.ChildUniqueReferenceID == nil && line.SplitLineGroupID != nil { - line.ChildUniqueReferenceID = line.SplitLineHierarchy.Group.UniqueReferenceID +func (s *SuiteBase) populateChildIDsFromParents(invoice billing.GenericInvoice) { + genericLinesOption := invoice.GetGenericLines() + if genericLinesOption.IsAbsent() { + s.Failf("lines not found", "lines not found for invoice %s", invoice.GetID()) + } + + genericLines := genericLinesOption.OrEmpty() + + for idx, line := range genericLines { + if line.GetChildUniqueReferenceID() == nil && line.GetSplitLineGroupID() != nil { + invoiceLine := line.AsInvoiceLine() + switch invoiceLine.Type() { + case billing.InvoiceLineTypeStandard: + stdInvoiceLine, err := invoiceLine.AsStandardLine() + s.NoError(err) + + line.SetChildUniqueReferenceID(stdInvoiceLine.SplitLineHierarchy.Group.UniqueReferenceID) + case billing.InvoiceLineTypeGathering: + splitLineGroupID := line.GetSplitLineGroupID() + if splitLineGroupID == nil { + s.Failf("split line group id not found", "split line group id not found for line %s", line.GetID()) + return + } + + splitLineGroup, err := s.BillingAdapter.GetSplitLineGroup(s.T().Context(), billing.GetSplitLineGroupInput{ + Namespace: s.Namespace, + ID: *splitLineGroupID, + }) + s.NoError(err) + + line.SetChildUniqueReferenceID(splitLineGroup.Group.UniqueReferenceID) + default: + s.Failf("unexpected line type", "unexpected line type %s for line %s", invoiceLine.Type(), line.GetID()) + } } + + genericLines[idx] = line } + + err := invoice.SetLines(genericLines) + s.NoError(err) } // helpers diff --git a/openmeter/billing/worker/subscriptionsync/service/sync.go b/openmeter/billing/worker/subscriptionsync/service/sync.go index 6f8587876a..e4fdd43ffa 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync.go @@ -30,7 +30,7 @@ const ( SubscriptionSyncComponentName billing.ComponentName = "subscription-sync" ) -type InvoiceByID map[string]billing.StandardInvoice +type InvoiceByID map[string]billing.Invoice func (i InvoiceByID) IsGatheringInvoice(invoiceID string) bool { invoice, ok := i[invoiceID] @@ -39,7 +39,7 @@ func (i InvoiceByID) IsGatheringInvoice(invoiceID string) bool { return true } - return invoice.Status == billing.StandardInvoiceStatusGathering + return invoice.Type() == billing.InvoiceTypeGathering } func (s *Service) invoicePendingLines(ctx context.Context, customer customer.CustomerID) error { @@ -151,9 +151,15 @@ func (s *Service) SynchronizeSubscription(ctx context.Context, subs subscription return fmt.Errorf("listing invoices: %w", err) } - invoiceByID := lo.SliceToMap(invoices.Items, func(i billing.StandardInvoice) (string, billing.StandardInvoice) { - return i.ID, i - }) + invoiceByID := make(InvoiceByID, len(invoices.Items)) + for _, invoice := range invoices.Items { + genericInvoice, err := invoice.AsGenericInvoice() + if err != nil { + return fmt.Errorf("converting invoice to generic invoice: %w", err) + } + + invoiceByID[genericInvoice.GetID()] = invoice + } // Calculate per line patches linesDiff, err := s.compareSubscriptionWithExistingLines(ctx, subs, asOf) diff --git a/openmeter/billing/worker/subscriptionsync/service/sync_test.go b/openmeter/billing/worker/subscriptionsync/service/sync_test.go index 95ce2dc1aa..8df6e682a4 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync_test.go @@ -221,15 +221,14 @@ func (s *SubscriptionHandlerTestSuite) TestSubscriptionHappyPath() { gatheringInvoice := s.gatheringInvoice(ctx, namespace, s.Customer.ID) s.NoError(err) - gatheringInvoiceID = gatheringInvoice.InvoiceID() + gatheringInvoiceID = gatheringInvoice.GetInvoiceID() s.DebugDumpInvoice("gathering invoice - 2nd update", gatheringInvoice) gatheringLine := gatheringInvoice.Lines.OrEmpty()[0] - s.Equal(invoiceUpdatedAt, gatheringInvoice.UpdatedAt) - s.Equal(billing.StandardInvoiceStatusGathering, gatheringInvoice.Status) - s.Equal(line.UpdatedAt, gatheringLine.UpdatedAt) + s.Equal(invoiceUpdatedAt, gatheringInvoice.GetUpdatedAt()) + s.Equal(line.GetUpdatedAt(), gatheringLine.GetUpdatedAt()) }) s.NoError(gatheringInvoiceID.Validate()) @@ -2122,9 +2121,6 @@ func (s *SubscriptionHandlerTestSuite) TestAlignedSubscriptionProgressiveBilling End: startTime.AddDate(0, 0, 1), }, }, - InvoiceAt: mo.Some([]time.Time{ - startTime.AddDate(0, 0, 1), - }), }, }) @@ -2552,7 +2548,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedGatheringUpdateDraftInvoice End: s.mustParseTime("2024-02-01T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-02-01T00:00:00Z")}), }, }) @@ -2768,7 +2763,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedGatheringUpdateIssuedInvoic End: s.mustParseTime("2024-02-01T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-02-01T00:00:00Z")}), }, }) @@ -2832,7 +2826,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedGatheringUpdateIssuedInvoic End: s.mustParseTime("2024-02-01T00:00:00Z"), // This is not updated, which is what we want }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-02-01T00:00:00Z")}), }, }) @@ -2923,7 +2916,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedUpdateWithLineSplits() { End: s.mustParseTime("2024-01-15T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-01-15T00:00:00Z")}), }, }) @@ -2958,7 +2950,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedUpdateWithLineSplits() { End: s.mustParseTime("2024-01-18T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-01-18T00:00:00Z")}), }, }) @@ -3089,7 +3080,6 @@ func (s *SubscriptionHandlerTestSuite) TestUsageBasedUpdateWithLineSplits() { End: s.mustParseTime("2024-01-15T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-01-15T00:00:00Z")}), }, }) @@ -3142,23 +3132,23 @@ func (s *SubscriptionHandlerTestSuite) TestGatheringManualEditSync() { gatheringInvoice := s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID) s.DebugDumpInvoice("gathering invoice", gatheringInvoice) - var updatedLine *billing.StandardLine - editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: gatheringInvoice.InvoiceID(), - EditFn: func(invoice *billing.StandardInvoice) error { - line := s.getLineByChildID(*invoice, fmt.Sprintf("%s/first-phase/in-advance/v[0]/period[0]", subsView.Subscription.ID)) + var updatedLine billing.GatheringLine + err := s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + EditFn: func(invoice *billing.GatheringInvoice) error { + line := s.getGatheringLineByChildID(*invoice, fmt.Sprintf("%s/first-phase/in-advance/v[0]/period[0]", subsView.Subscription.ID)) - price, err := line.UsageBased.Price.AsFlat() + price, err := line.Price.AsFlat() s.NoError(err) price.PaymentTerm = productcatalog.InArrearsPaymentTerm - line.UsageBased.Price = productcatalog.NewPriceFrom(price) + line.Price = *productcatalog.NewPriceFrom(price) - line.Period = billing.Period{ - Start: line.Period.Start.Add(time.Hour), - End: line.Period.End.Add(time.Hour), + line.ServicePeriod = timeutil.ClosedPeriod{ + From: line.ServicePeriod.From.Add(time.Hour), + To: line.ServicePeriod.To.Add(time.Hour), } - line.InvoiceAt = line.Period.End + line.InvoiceAt = line.ServicePeriod.To line.ManagedBy = billing.ManuallyManagedLine updatedLine, err = line.Clone() @@ -3166,6 +3156,15 @@ func (s *SubscriptionHandlerTestSuite) TestGatheringManualEditSync() { return nil }, }) + s.NoError(err) + + editedInvoice, err := s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) s.NoError(err) s.DebugDumpInvoice("edited invoice", editedInvoice) @@ -3176,8 +3175,8 @@ func (s *SubscriptionHandlerTestSuite) TestGatheringManualEditSync() { s.DebugDumpInvoice("gathering invoice - after sync", gatheringInvoice) // Then the line should not be updated - invoiceLine := s.getLineByChildID(gatheringInvoice, *updatedLine.ChildUniqueReferenceID) - s.True(invoiceLine.StandardLineBase.Equal(updatedLine.StandardLineBase), "line should not be updated") + invoiceLine := s.getGatheringLineByChildID(gatheringInvoice, *updatedLine.ChildUniqueReferenceID) + s.True(invoiceLine.GatheringLineBase.Equal(updatedLine.GatheringLineBase), "line should not be updated") } func (s *SubscriptionHandlerTestSuite) TestSplitLineManualEditSync() { @@ -3231,7 +3230,7 @@ func (s *SubscriptionHandlerTestSuite) TestSplitLineManualEditSync() { s.DebugDumpInvoice("draft invoice", draftInvoice) var updatedLine *billing.StandardLine - editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + editedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: draftInvoice.InvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { lines := invoice.Lines.OrEmpty() @@ -3321,14 +3320,14 @@ func (s *SubscriptionHandlerTestSuite) TestGatheringManualDeleteSync() { gatheringInvoice := s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID) s.DebugDumpInvoice("gathering invoice", gatheringInvoice) - var updatedLine *billing.StandardLine + var updatedLine billing.GatheringLine childUniqueReferenceID := fmt.Sprintf("%s/first-phase/in-advance/v[0]/period[0]", subsView.Subscription.ID) - editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: gatheringInvoice.InvoiceID(), - EditFn: func(invoice *billing.StandardInvoice) error { - line := s.getLineByChildID(*invoice, childUniqueReferenceID) + err := s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + EditFn: func(invoice *billing.GatheringInvoice) error { + line := s.getGatheringLineByChildID(*invoice, childUniqueReferenceID) line.DeletedAt = lo.ToPtr(clock.Now()) line.ManagedBy = billing.ManuallyManagedLine @@ -3340,7 +3339,16 @@ func (s *SubscriptionHandlerTestSuite) TestGatheringManualDeleteSync() { }) s.NoError(err) - updatedLineFromEditedInvoice := s.getLineByChildID(editedInvoice, childUniqueReferenceID) + editedInvoice, err := s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) + s.NoError(err) + + updatedLineFromEditedInvoice := s.getGatheringLineByChildID(editedInvoice, childUniqueReferenceID) s.NotNil(updatedLineFromEditedInvoice.DeletedAt) s.Equal(billing.ManuallyManagedLine, updatedLineFromEditedInvoice.ManagedBy) @@ -3419,10 +3427,10 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLines() { s.Equal(gatheringLineReferenceID, *gatheringLines[0].ChildUniqueReferenceID) // Now let's manually mark the lines as sync ignored - _, err = s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + _, err = s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: draftInvoice.InvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { - line := s.getLineByChildID(*invoice, draftLineReferenceID) + line := s.getStandardLineByChildID(*invoice, draftLineReferenceID) line.Annotations = models.Annotations{ billing.AnnotationSubscriptionSyncIgnore: true, @@ -3434,12 +3442,12 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLines() { }) s.NoError(err) - var gatheringInvoiceIgnoredLine *billing.StandardLine + var gatheringInvoiceIgnoredLine billing.GatheringLine - gatheringInvoice, err = s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: gatheringInvoice.InvoiceID(), - EditFn: func(invoice *billing.StandardInvoice) error { - line := s.getLineByChildID(*invoice, gatheringLineReferenceID) + err = s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + EditFn: func(invoice *billing.GatheringInvoice) error { + line := s.getGatheringLineByChildID(*invoice, gatheringLineReferenceID) line.Annotations = models.Annotations{ billing.AnnotationSubscriptionSyncIgnore: true, @@ -3516,7 +3524,7 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLines() { gatheringInvoice = s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID) s.DebugDumpInvoice("gathering invoice - after sync", gatheringInvoice) - gatheringInvoiceIgnoredLineAfterSync := s.getLineByChildID(gatheringInvoice, *gatheringInvoiceIgnoredLine.ChildUniqueReferenceID) + gatheringInvoiceIgnoredLineAfterSync := s.getGatheringLineByChildID(gatheringInvoice, *gatheringInvoiceIgnoredLine.ChildUniqueReferenceID) s.Equal(lo.Must(gatheringInvoiceIgnoredLine.RemoveMetaForCompare()), lo.Must(gatheringInvoiceIgnoredLineAfterSync.RemoveMetaForCompare())) // But the non-marked line should be deleted @@ -3532,10 +3540,10 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLines() { s.Len(updatedGartheringLines, 4) newLineReferenceID := fmt.Sprintf("%s/first-phase/in-advance/v[1]/period[0]", subsView.Subscription.ID) - updatedLine := s.getLineByChildID(gatheringInvoice, newLineReferenceID) + updatedLine := s.getGatheringLineByChildID(gatheringInvoice, newLineReferenceID) s.NotNil(updatedLine) - price, err := updatedLine.UsageBased.Price.AsFlat() + price, err := updatedLine.Price.AsFlat() s.NoError(err) s.Equal(alpacadecimal.NewFromFloat(10), price.Amount) } @@ -3596,10 +3604,10 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLinesWhenPeriod unMarkedLineReferenceID := fmt.Sprintf("%s/first-phase/non-marked/v[0]/period[0]", subsView.Subscription.ID) // Now let's manually mark the lines as sync ignored - gatheringInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: gatheringInvoice.InvoiceID(), - EditFn: func(invoice *billing.StandardInvoice) error { - line := s.getLineByChildID(*invoice, markedLineReferenceID) + err := s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: gatheringInvoice.GetInvoiceID(), + EditFn: func(invoice *billing.GatheringInvoice) error { + line := s.getGatheringLineByChildID(*invoice, markedLineReferenceID) line.Annotations = models.Annotations{ billing.AnnotationSubscriptionSyncIgnore: true, @@ -3627,18 +3635,18 @@ func (s *SubscriptionHandlerTestSuite) TestManualIgnoringOfSyncedLinesWhenPeriod s.DebugDumpInvoice("gathering invoice - after sync", gatheringInvoice) // And assert that everything works as expected - markedLine := s.getLineByChildID(gatheringInvoice, markedLineReferenceID) + markedLine := s.getGatheringLineByChildID(gatheringInvoice, markedLineReferenceID) s.NotNil(markedLine) - s.Equal(markedLine.Period, billing.Period{ - Start: s.mustParseTime("2024-01-01T00:00:00Z"), - End: s.mustParseTime("2024-04-01T00:00:00Z"), // period wasn't updated + s.Equal(markedLine.ServicePeriod, timeutil.ClosedPeriod{ + From: s.mustParseTime("2024-01-01T00:00:00Z"), + To: s.mustParseTime("2024-04-01T00:00:00Z"), // period wasn't updated }) - unmarkedLine := s.getLineByChildID(gatheringInvoice, unMarkedLineReferenceID) + unmarkedLine := s.getGatheringLineByChildID(gatheringInvoice, unMarkedLineReferenceID) s.NotNil(unmarkedLine) - s.Equal(unmarkedLine.Period, billing.Period{ - Start: s.mustParseTime("2024-01-01T00:00:00Z"), - End: s.mustParseTime("2024-02-01T00:00:00Z"), // period was updated + s.Equal(unmarkedLine.ServicePeriod, timeutil.ClosedPeriod{ + From: s.mustParseTime("2024-01-01T00:00:00Z"), + To: s.mustParseTime("2024-02-01T00:00:00Z"), // period was updated }) } @@ -3693,7 +3701,7 @@ func (s *SubscriptionHandlerTestSuite) TestSplitLineManualDeleteSync() { s.DebugDumpInvoice("gathering invoice - after invoicing", s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID)) var updatedLine *billing.StandardLine - editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + editedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: draftInvoice.InvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { lines := invoice.Lines.OrEmpty() @@ -3910,7 +3918,6 @@ func (s *SubscriptionHandlerTestSuite) TestInAdvanceInstantBillingOnSubscription End: s.mustParseTime("2024-02-01T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-01-01T00:00:00Z")}), }, }) } @@ -3963,9 +3970,13 @@ func (s *SubscriptionHandlerTestSuite) TestInAdvanceInstantBillingOnSubscription s.NoError(s.Service.SynchronizeSubscriptionAndInvoiceCustomer(ctx, subsView, clock.Now())) - invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ - Customers: []string{s.Customer.ID}, - Expand: billing.InvoiceExpandAll, + invoices, err := s.BillingService.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ + Namespaces: []string{s.Namespace}, + Customers: []string{s.Customer.ID}, + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, }) s.NoError(err) s.Len(invoices.Items, 1) @@ -4052,12 +4063,21 @@ func (s *SubscriptionHandlerTestSuite) TestDiscountSynchronization() { s.NoError(err) s.Len(invoices.Items, 2) - var gatheringInvoice *billing.StandardInvoice + var gatheringInvoice *billing.GatheringInvoice var instantInvoice *billing.StandardInvoice for _, invoice := range invoices.Items { if invoice.Status == billing.StandardInvoiceStatusGathering { - gatheringInvoice = &invoice + // TODO: let's use generic listing call once ready + fetchedGatheringInvoice, err := s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: invoice.GetInvoiceID(), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) + s.NoError(err) + gatheringInvoice = &fetchedGatheringInvoice continue } @@ -4128,7 +4148,6 @@ func (s *SubscriptionHandlerTestSuite) TestDiscountSynchronization() { End: s.mustParseTime("2024-02-01T00:00:00Z"), }, }, - InvoiceAt: mo.Some([]time.Time{s.mustParseTime("2024-01-01T00:00:00Z")}), }, }) @@ -4435,19 +4454,19 @@ func (s *SubscriptionHandlerTestSuite) TestSynchronizeSubscriptionPeriodAlgorith invoice := s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID) s.DebugDumpInvoice("gathering invoice", invoice) - invoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), - EditFn: func(invoice *billing.StandardInvoice) error { + err := s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: invoice.GetInvoiceID(), + EditFn: func(invoice *billing.GatheringInvoice) error { line := invoice.Lines.OrEmpty()[0] // simulate some faulty behavior (the old algo would have set the end to 03-03, but this way we can test this with both the old and new alog) - line.Period.Start = s.mustParseTime("2025-01-31T00:00:00Z") - line.Period.End = s.mustParseTime("2025-03-02T00:00:00Z") + line.ServicePeriod.From = s.mustParseTime("2025-01-31T00:00:00Z") + line.ServicePeriod.To = s.mustParseTime("2025-03-02T00:00:00Z") line.Annotations = models.Annotations{ billing.AnnotationSubscriptionSyncIgnore: true, billing.AnnotationSubscriptionSyncForceContinuousLines: true, } - invoice.Lines = billing.NewStandardInvoiceLines([]*billing.StandardLine{ + invoice.Lines = billing.NewGatheringInvoiceLines([]billing.GatheringLine{ line, }) return nil @@ -4455,6 +4474,15 @@ func (s *SubscriptionHandlerTestSuite) TestSynchronizeSubscriptionPeriodAlgorith }) s.NoError(err) + invoice, err = s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: invoice.GetInvoiceID(), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) + s.NoError(err) + s.DebugDumpInvoice("gathering invoice - updated", invoice) s.expectLines(invoice, subsView.Subscription.ID, []expectedLine{ { diff --git a/openmeter/billing/worker/subscriptionsync/service/syncbillinganchor_test.go b/openmeter/billing/worker/subscriptionsync/service/syncbillinganchor_test.go index 725d782587..73bb41a212 100644 --- a/openmeter/billing/worker/subscriptionsync/service/syncbillinganchor_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/syncbillinganchor_test.go @@ -183,9 +183,9 @@ func (s *BillingAnchorTestSuite) TestBillingAnchorSinglePhase() { InvoiceAt: mo.Some([]time.Time{ testutils.GetRFC3339Time(s.T(), "2025-07-31T15:00:00Z"), }), - AdditionalChecks: func(line *billing.StandardLine) { - s.Equal(testutils.GetRFC3339Time(s.T(), "2025-07-10T15:00:00Z"), line.Subscription.BillingPeriod.From) - s.Equal(testutils.GetRFC3339Time(s.T(), "2025-07-31T15:00:00Z"), line.Subscription.BillingPeriod.To) + AdditionalChecks: func(line billing.GenericInvoiceLine) { + s.Equal(testutils.GetRFC3339Time(s.T(), "2025-07-10T15:00:00Z"), line.GetSubscriptionReference().BillingPeriod.From) + s.Equal(testutils.GetRFC3339Time(s.T(), "2025-07-31T15:00:00Z"), line.GetSubscriptionReference().BillingPeriod.To) }, }, { diff --git a/openmeter/productcatalog/price.go b/openmeter/productcatalog/price.go index 3eb149b636..bb3fc0b682 100644 --- a/openmeter/productcatalog/price.go +++ b/openmeter/productcatalog/price.go @@ -121,7 +121,7 @@ func (p *Price) Clone() *Price { return clone } -func (p *Price) MarshalJSON() ([]byte, error) { +func (p Price) MarshalJSON() ([]byte, error) { var b []byte var err error var serde interface{} diff --git a/openmeter/server/server_test.go b/openmeter/server/server_test.go index a101228e6a..4ccf4b8ae4 100644 --- a/openmeter/server/server_test.go +++ b/openmeter/server/server_test.go @@ -1508,7 +1508,11 @@ func (n NoopBillingService) DeleteInvoice(ctx context.Context, input billing.Del return billing.StandardInvoice{}, nil } -func (n NoopBillingService) UpdateInvoice(ctx context.Context, input billing.UpdateInvoiceInput) (billing.StandardInvoice, error) { +func (n NoopBillingService) UpdateInvoice(ctx context.Context, input billing.UpdateInvoiceInput) (billing.Invoice, error) { + return billing.Invoice{}, nil +} + +func (n NoopBillingService) UpdateStandardInvoice(ctx context.Context, input billing.UpdateStandardInvoiceInput) (billing.StandardInvoice, error) { return billing.StandardInvoice{}, nil } @@ -1537,6 +1541,10 @@ func (n NoopBillingService) UpdateGatheringInvoice(ctx context.Context, input bi return nil } +func (n NoopBillingService) GetGatheringInvoiceById(ctx context.Context, input billing.GetGatheringInvoiceByIdInput) (billing.GatheringInvoice, error) { + return billing.GatheringInvoice{}, nil +} + // SequenceService methods func (n NoopBillingService) GenerateInvoiceSequenceNumber(ctx context.Context, in billing.SequenceGenerationInput, def billing.SequenceDefinition) (string, error) { return "", nil diff --git a/test/app/stripe/invoice_test.go b/test/app/stripe/invoice_test.go index 3c5ccd4b56..1a9b5f2d21 100644 --- a/test/app/stripe/invoice_test.go +++ b/test/app/stripe/invoice_test.go @@ -1188,7 +1188,7 @@ func (s *StripeInvoiceTestSuite) TestEmptyInvoiceGenerationZeroUsage() { }, }, nil) - invoice, err = s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + invoice, err = s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: invoice.InvoiceID(), EditFn: func(i *billing.StandardInvoice) error { i.Supplier.Name = "ACME Inc. (updated)" diff --git a/test/billing/adapter_test.go b/test/billing/adapter_test.go index bde229e0cd..399e3429c1 100644 --- a/test/billing/adapter_test.go +++ b/test/billing/adapter_test.go @@ -715,7 +715,7 @@ func (s *BillingAdapterTestSuite) TestHardDeleteGatheringInvoiceLines() { err := s.BillingAdapter.HardDeleteGatheringInvoiceLines(ctx, gatheringInvoice.GetInvoiceID(), []string{deletedLine.ID}) s.NoError(err) - gatheringInvoice, err = s.BillingAdapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + gatheringInvoice, err = s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ Invoice: gatheringInvoice.GetInvoiceID(), Expand: billing.GatheringInvoiceExpands{billing.GatheringInvoiceExpandLines}, }) @@ -841,7 +841,7 @@ func (s *BillingAdapterTestSuite) TestHardDeleteGatheringInvoiceLinesNegative() standardInvoice = standardInvoices[0] - gatheringInvoice, err = s.BillingAdapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + gatheringInvoice, err = s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ Invoice: createdPendingLines.Invoice.GetInvoiceID(), Expand: billing.GatheringInvoiceExpands{billing.GatheringInvoiceExpandLines}, }) diff --git a/test/billing/collection_test.go b/test/billing/collection_test.go index f3f8a6142d..ae90530065 100644 --- a/test/billing/collection_test.go +++ b/test/billing/collection_test.go @@ -368,8 +368,8 @@ func (s *CollectionTestSuite) TestCollectionFlowWithFlatFeeEditing() { // When adding a flat fee (in arrears) s.MockStreamingConnector.AddSimpleEvent(apiRequestsTotalFeature.Feature.Key, 1, periodStart.Add(time.Minute*35)) - invoice, err = s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + invoice, err = s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { linePeriod := billing.Period{ Start: periodEnd.Add(time.Hour * 1), @@ -544,8 +544,8 @@ func (s *CollectionTestSuite) TestCollectionFlowWithUBPEditingExtendingCollectio End: lo.Must(time.Parse(time.RFC3339, "2025-01-03T00:00:00Z")), } s.Run("adding a new line extends the collection period", func() { - invoice, err = s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + invoice, err = s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { invoice.Lines.Append(&billing.StandardLine{ StandardLineBase: billing.StandardLineBase{ diff --git a/test/billing/discount_test.go b/test/billing/discount_test.go index 4acf7de97c..59dac6d60e 100644 --- a/test/billing/discount_test.go +++ b/test/billing/discount_test.go @@ -175,7 +175,7 @@ func (s *DiscountsTestSuite) TestCorrelationIDHandling() { }) s.Run("Editing an invoice and adding a new discount generates a new correlation ID", func() { - editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + editedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: draftInvoiceID, EditFn: func(invoice *billing.StandardInvoice) error { line := invoice.Lines.OrEmpty()[0] diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 30e1d86e5a..7160bf53fb 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -770,8 +770,8 @@ func (s *InvoicingTestSuite) TestInvoicingFlow() { require.Equal(s.T(), billing.StandardInvoiceStatusDraftManualApprovalNeeded, invoice.Status) // Let's instruct the sandbox to fail the invoice - _, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + _, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { invoice.Metadata = map[string]string{ appsandbox.TargetPaymentStatusMetadataKey: appsandbox.TargetPaymentStatusFailed, @@ -1740,7 +1740,7 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { }, out[0].Totals) s.Run("update line item", func() { - updatedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + updatedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ Invoice: invoice.InvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { line := invoice.Lines.GetByID(flatPerUnit.ID) @@ -1788,8 +1788,8 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { }) s.Run("invalid update of a line item", func() { - _, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + _, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { line := invoice.Lines.GetByID(flatPerUnit.ID) if line == nil { @@ -1817,8 +1817,8 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { }) s.Run("deleting a valid line item worked", func() { - updatedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: invoice.InvoiceID(), + updatedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: invoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { line := invoice.Lines.GetByID(flatPerUnit.ID) if line == nil { @@ -4331,3 +4331,254 @@ func (s *InvoicingTestSuite) TestGatheringInvoiceEmulation() { require.Equal(s.T(), profile.Supplier.Name, invoice.Supplier.Name) require.Equal(s.T(), sandboxApp.GetID(), invoice.Workflow.Apps.Invoicing.GetID()) } + +func (s *InvoicingTestSuite) TestUpdateInvoice() { + ctx := context.Background() + namespace := s.GetUniqueNamespace("ns-update-invoice") + now := clock.Now().Truncate(time.Second).UTC() + periodStart := now.Add(-48 * time.Hour) + periodEnd := now.Add(-24 * time.Hour) + testLines := []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ + Namespace: namespace, + Period: billing.Period{Start: periodStart, End: periodEnd}, + InvoiceAt: now, + ManagedBy: billing.ManuallyManagedLine, + Name: "line-active", + PerUnitAmount: alpacadecimal.NewFromFloat(100), + PaymentTerm: productcatalog.InArrearsPaymentTerm, + }), + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ + Namespace: namespace, + Period: billing.Period{Start: periodStart, End: periodEnd}, + InvoiceAt: now, + ManagedBy: billing.ManuallyManagedLine, + Name: "line-deleted", + PerUnitAmount: alpacadecimal.NewFromFloat(200), + PaymentTerm: productcatalog.InArrearsPaymentTerm, + }), + } + + sandboxApp := s.InstallSandboxApp(s.T(), namespace) + customerEntity := s.CreateTestCustomer(namespace, "test-update-invoice") + s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID(), WithBillingProfileEditFn(func(profile *billing.CreateProfileInput) { + profile.WorkflowConfig = billing.WorkflowConfig{ + Collection: billing.CollectionConfig{ + Alignment: billing.AlignmentKindSubscription, + }, + Invoicing: billing.InvoicingConfig{ + AutoAdvance: false, + DraftPeriod: lo.Must(datetime.ISODurationString("PT0S").Parse()), + DueAfter: lo.Must(datetime.ISODurationString("P1W").Parse()), + }, + Payment: billing.PaymentConfig{ + CollectionMethod: billing.CollectionMethodChargeAutomatically, + }, + } + })) + + s.Run("gathering invoice", func() { + var gatheringInvoiceID billing.InvoiceID + var activeLineID string + var deletedLineID string + + s.Run("given a gathering invoice with a line and a deleted line", func() { + res, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ + Customer: customerEntity.GetID(), + Currency: currencyx.Code(currency.USD), + Lines: testLines, + }) + require.NoError(s.T(), err) + require.Len(s.T(), res.Lines, 2) + + gatheringInvoiceID = res.Invoice.GetInvoiceID() + activeLineID = res.Lines[0].ID + deletedLineID = res.Lines[1].ID + + err = s.BillingService.UpdateGatheringInvoice(ctx, billing.UpdateGatheringInvoiceInput{ + Invoice: gatheringInvoiceID, + IncludeDeletedLines: true, + EditFn: func(invoice *billing.GatheringInvoice) error { + line, ok := invoice.Lines.GetByID(deletedLineID) + if !ok { + return fmt.Errorf("line[%s] not found", deletedLineID) + } + + line.DeletedAt = lo.ToPtr(clock.Now()) + + return invoice.Lines.ReplaceByID(line) + }, + }) + require.NoError(s.T(), err) + + invoice, err := s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) + require.NoError(s.T(), err) + require.Len(s.T(), invoice.Lines.OrEmpty(), 2) + + deletedLine, ok := invoice.Lines.GetByID(deletedLineID) + require.True(s.T(), ok) + require.NotNil(s.T(), deletedLine.DeletedAt) + }) + + s.Run("when editing using UpdateInvoice", func() { + var sawDeletedLine bool + + updatedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + Invoice: gatheringInvoiceID, + IncludeDeletedLines: true, + EditFn: func(invoice billing.Invoice) (billing.Invoice, error) { + gatheringInvoice, err := invoice.AsGatheringInvoice() + if err != nil { + return billing.Invoice{}, err + } + + deletedLine, ok := gatheringInvoice.Lines.GetByID(deletedLineID) + sawDeletedLine = ok && deletedLine.DeletedAt != nil + + activeLine, ok := gatheringInvoice.Lines.GetByID(activeLineID) + if !ok { + return billing.Invoice{}, fmt.Errorf("line[%s] not found", activeLineID) + } + + activeLine.Name = "gathering-line-active-updated" + + if err := gatheringInvoice.Lines.ReplaceByID(activeLine); err != nil { + return billing.Invoice{}, err + } + + return billing.NewInvoice(gatheringInvoice), nil + }, + }) + require.NoError(s.T(), err) + require.True(s.T(), sawDeletedLine, "edit fn should receive deleted lines when include deleted lines is set") + + updatedGatheringInvoice, err := updatedInvoice.AsGatheringInvoice() + require.NoError(s.T(), err) + + updatedLine, ok := updatedGatheringInvoice.Lines.GetByID(activeLineID) + require.True(s.T(), ok) + require.Equal(s.T(), "gathering-line-active-updated", updatedLine.Name) + + s.Run("then the invoice gets updated", func() { + reloadedInvoice, err := s.BillingService.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + billing.GatheringInvoiceExpandDeletedLines, + }, + }) + require.NoError(s.T(), err) + + reloadedActiveLine, ok := reloadedInvoice.Lines.GetByID(activeLineID) + require.True(s.T(), ok) + require.Equal(s.T(), "gathering-line-active-updated", reloadedActiveLine.Name) + }) + }) + }) + + s.Run("draft invoice with manual approval", func() { + var draftInvoice billing.StandardInvoice + var activeLineID string + var deletedLineID string + + s.Run("given a draft invoice with a line and a deleted line", func() { + pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ + Customer: customerEntity.GetID(), + Currency: currencyx.Code(currency.USD), + Lines: testLines, + }) + require.NoError(s.T(), err) + require.Len(s.T(), pendingLines.Lines, 2) + + invoices, err := s.BillingService.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ + Customer: customerEntity.GetID(), + IncludePendingLines: mo.Some(lo.Map(pendingLines.Lines, func(line billing.GatheringLine, _ int) string { + return line.ID + })), + AsOf: lo.ToPtr(now), + }) + require.NoError(s.T(), err) + require.Len(s.T(), invoices, 1) + draftInvoice = invoices[0] + + require.Equal(s.T(), billing.StandardInvoiceStatusDraftManualApprovalNeeded, draftInvoice.Status) + require.Len(s.T(), draftInvoice.Lines.MustGet(), 2) + + activeLineID = draftInvoice.Lines.MustGet()[0].ID + deletedLineID = draftInvoice.Lines.MustGet()[1].ID + + _, err = s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: draftInvoice.InvoiceID(), + IncludeDeletedLines: true, + EditFn: func(invoice *billing.StandardInvoice) error { + line := invoice.Lines.GetByID(deletedLineID) + if line == nil { + return fmt.Errorf("line[%s] not found", deletedLineID) + } + + line.DeletedAt = lo.ToPtr(clock.Now()) + return nil + }, + }) + require.NoError(s.T(), err) + }) + + s.Run("when editing using UpdateInvoice", func() { + var sawDeletedLine bool + + updatedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ + Invoice: draftInvoice.InvoiceID(), + IncludeDeletedLines: true, + EditFn: func(invoice billing.Invoice) (billing.Invoice, error) { + standardInvoice, err := invoice.AsStandardInvoice() + if err != nil { + return billing.Invoice{}, err + } + + deletedLine := standardInvoice.Lines.GetByID(deletedLineID) + sawDeletedLine = deletedLine != nil && deletedLine.DeletedAt != nil + + activeLine := standardInvoice.Lines.GetByID(activeLineID) + if activeLine == nil { + return billing.Invoice{}, fmt.Errorf("line[%s] not found", activeLineID) + } + + activeLine.Name = "draft-line-active-updated" + + return billing.NewInvoice(standardInvoice), nil + }, + }) + require.NoError(s.T(), err) + require.True(s.T(), sawDeletedLine, "edit fn should receive deleted lines when include deleted lines is set") + + updatedStandardInvoice, err := updatedInvoice.AsStandardInvoice() + require.NoError(s.T(), err) + + updatedLine := updatedStandardInvoice.Lines.GetByID(activeLineID) + require.NotNil(s.T(), updatedLine) + require.Equal(s.T(), "draft-line-active-updated", updatedLine.Name) + + s.Run("then the invoice gets updated", func() { + reloadedInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: draftInvoice.InvoiceID(), + Expand: billing.InvoiceExpandAll.SetDeletedLines(true), + }) + require.NoError(s.T(), err) + + reloadedActiveLine := reloadedInvoice.Lines.GetByID(activeLineID) + require.NotNil(s.T(), reloadedActiveLine) + require.Equal(s.T(), "draft-line-active-updated", reloadedActiveLine.Name) + + reloadedDeletedLine := reloadedInvoice.Lines.GetByID(deletedLineID) + require.NotNil(s.T(), reloadedDeletedLine) + require.NotNil(s.T(), reloadedDeletedLine.DeletedAt) + }) + }) + }) +} diff --git a/test/billing/suite.go b/test/billing/suite.go index fbd496b5a6..668ded4ab9 100644 --- a/test/billing/suite.go +++ b/test/billing/suite.go @@ -312,9 +312,26 @@ func (s *BaseSuite) TearDownSuite() { s.TestDB.PGDriver.Close() } -func (s *BaseSuite) DebugDumpInvoice(h string, i billing.StandardInvoice) { +func (s *BaseSuite) DebugDumpInvoice(h string, i billing.GenericInvoiceReader) { s.T().Log(h) + invoice := i.AsInvoice() + switch invoice.Type() { + case billing.InvoiceTypeStandard: + standardInvoice, err := invoice.AsStandardInvoice() + s.NoError(err) + + s.DebugDumpStandardInvoice(h, standardInvoice) + case billing.InvoiceTypeGathering: + gatheringInvoice, err := invoice.AsGatheringInvoice() + s.NoError(err) + s.DebugDumpGatheringInvoice(h, gatheringInvoice) + default: + s.Fail("invalid invoice type: %s", invoice.Type()) + } +} + +func (s *BaseSuite) DebugDumpStandardInvoice(h string, i billing.StandardInvoice) { l := i.Lines.OrEmpty() slices.SortFunc(l, func(l1, l2 *billing.StandardLine) int { @@ -347,6 +364,37 @@ func (s *BaseSuite) DebugDumpInvoice(h string, i billing.StandardInvoice) { } } +func (s *BaseSuite) DebugDumpGatheringInvoice(h string, i billing.GatheringInvoice) { + l := i.Lines.OrEmpty() + + slices.SortFunc(l, func(l1, l2 billing.GatheringLine) int { + if l1.ServicePeriod.From.Before(l2.ServicePeriod.From) { + return -1 + } else if l1.ServicePeriod.From.After(l2.ServicePeriod.From) { + return 1 + } + return 0 + }) + + for _, line := range i.Lines.OrEmpty() { + deleted := "" + if line.DeletedAt != nil { + deleted = " (deleted)" + } + + priceJson, err := json.Marshal(&line.Price) + s.NoError(err) + + s.T().Logf("usage[%s..%s] childUniqueReferenceID: %s, invoiceAt: %s, qty: N/A, price: %s (total=N/A) %s\n", + line.ServicePeriod.From.Format(time.RFC3339), + line.ServicePeriod.To.Format(time.RFC3339), + lo.FromPtrOr(line.ChildUniqueReferenceID, "null"), + line.InvoiceAt.Format(time.RFC3339), + string(priceJson), + deleted) + } +} + type DraftInvoiceInput struct { Namespace string Customer *customer.Customer diff --git a/test/billing/tax_test.go b/test/billing/tax_test.go index 7cfd908327..378083d178 100644 --- a/test/billing/tax_test.go +++ b/test/billing/tax_test.go @@ -121,8 +121,8 @@ func (s *InvoicingTaxTestSuite) TestDefaultTaxConfigProfileSnapshotting() { s.Nil(draftInvoice.Workflow.Config.Invoicing.DefaultTaxConfig) // let's update the invoice - updatedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ - Invoice: draftInvoice.InvoiceID(), + updatedInvoice, err := s.BillingService.UpdateStandardInvoice(ctx, billing.UpdateStandardInvoiceInput{ + Invoice: draftInvoice.GetInvoiceID(), EditFn: func(invoice *billing.StandardInvoice) error { invoice.Workflow.Config.Invoicing.DefaultTaxConfig = &productcatalog.TaxConfig{ Behavior: lo.ToPtr(productcatalog.InclusiveTaxBehavior),