Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions openmeter/app/custominvoicing/httpdriver/custominvoicing.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (h *handler) DraftSyncronized() DraftSyncronizedHandler {
return DraftSyncronizedResponse{}, err
}

return billinghttpdriver.MapInvoiceToAPI(invoice)
return billinghttpdriver.MapStandardInvoiceToAPI(invoice)
},
commonhttp.JSONResponseEncoderWithStatus[DraftSyncronizedResponse](http.StatusOK),
httptransport.AppendOptions(
Expand Down Expand Up @@ -104,7 +104,7 @@ func (h *handler) IssuingSyncronized() IssuingSyncronizedHandler {
return IssuingSyncronizedResponse{}, err
}

return billinghttpdriver.MapInvoiceToAPI(invoice)
return billinghttpdriver.MapStandardInvoiceToAPI(invoice)
},
commonhttp.JSONResponseEncoderWithStatus[IssuingSyncronizedResponse](http.StatusOK),
httptransport.AppendOptions(
Expand Down Expand Up @@ -160,7 +160,7 @@ func (h *handler) UpdatePaymentStatus() UpdatePaymentStatusHandler {
return UpdatePaymentStatusResponse{}, err
}

return billinghttpdriver.MapInvoiceToAPI(invoice)
return billinghttpdriver.MapStandardInvoiceToAPI(invoice)
},
commonhttp.JSONResponseEncoderWithStatus[UpdatePaymentStatusResponse](http.StatusOK),
httptransport.AppendOptions(
Expand Down
2 changes: 1 addition & 1 deletion openmeter/app/stripe/service/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func stripeErrorToValidationError(stripeErr *stripe.Error) error {

// getInvoiceByStripeID retrieves an invoice by its stripe ID, it returns nil if the invoice is not found (thus not managed by the app)
func (s *Service) getInvoiceByStripeID(ctx context.Context, appID app.AppID, stripeInvoiceID string) (*billing.StandardInvoice, error) {
invoices, err := s.billingService.ListInvoices(ctx, billing.ListInvoicesInput{
invoices, err := s.billingService.ListStandardInvoices(ctx, billing.ListStandardInvoicesInput{
Namespaces: []string{appID.Namespace},
ExternalIDs: &billing.ListInvoicesExternalIDFilter{
Type: billing.InvoicingExternalIDType,
Expand Down
8 changes: 5 additions & 3 deletions openmeter/billing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ type InvoiceLineAdapter interface {

type InvoiceAdapter interface {
CreateInvoice(ctx context.Context, input CreateInvoiceAdapterInput) (CreateInvoiceAdapterRespone, error)
GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (StandardInvoice, error)
DeleteGatheringInvoices(ctx context.Context, input DeleteGatheringInvoicesInput) error
GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (Invoice, error)
ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error)
AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error)
UpdateInvoice(ctx context.Context, input UpdateInvoiceAdapterInput) (StandardInvoice, error)

GetInvoiceOwnership(ctx context.Context, input GetInvoiceOwnershipAdapterInput) (GetOwnershipAdapterResponse, error)

GetInvoiceType(ctx context.Context, input GetInvoiceTypeAdapterInput) (InvoiceType, error)
}

type GatheringInvoiceAdapter interface {
CreateGatheringInvoice(ctx context.Context, input CreateGatheringInvoiceAdapterInput) (GatheringInvoice, error)
UpdateGatheringInvoice(ctx context.Context, input UpdateGatheringInvoiceAdapterInput) error
DeleteGatheringInvoice(ctx context.Context, input DeleteGatheringInvoiceAdapterInput) error
GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error)
// TODO: remove
// GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error)
ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error)

HardDeleteGatheringInvoiceLines(ctx context.Context, invoiceID InvoiceID, lineIDs []string) error
Expand Down
14 changes: 13 additions & 1 deletion openmeter/billing/adapter/gatheringinvoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ func (a *adapter) ListGatheringInvoices(ctx context.Context, input billing.ListG
query = query.Where(billinginvoice.CurrencyIn(input.Currencies...))
}

if len(input.IDs) > 0 {
query = query.Where(billinginvoice.IDIn(input.IDs...))
}

if input.NextCollectionAtBeforeOrEqual != nil {
query = query.Where(billinginvoice.CollectionAtLTE(*input.NextCollectionAtBeforeOrEqual))
}

order := entutils.GetOrdering(sortx.OrderDefault)
if !input.Order.IsDefaultValue() {
order = entutils.GetOrdering(input.Order)
Expand Down Expand Up @@ -303,6 +311,7 @@ func (a *adapter) DeleteGatheringInvoice(ctx context.Context, input billing.Dele
invoice, err := tx.db.BillingInvoice.Query().
Where(billinginvoice.ID(input.ID)).
Where(billinginvoice.Namespace(input.Namespace)).
Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)).
Only(ctx)
if err != nil {
return err
Expand All @@ -322,6 +331,8 @@ func (a *adapter) DeleteGatheringInvoice(ctx context.Context, input billing.Dele
Where(billinginvoice.ID(input.ID)).
Where(billinginvoice.Namespace(input.Namespace)).
SetDeletedAt(clock.Now()).
ClearPeriodStart().
ClearPeriodEnd().
Save(ctx)
if err != nil {
return err
Expand Down Expand Up @@ -351,7 +362,8 @@ func (a *adapter) GetGatheringInvoiceById(ctx context.Context, input billing.Get
return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GatheringInvoice, error) {
query := tx.db.BillingInvoice.Query().
Where(billinginvoice.ID(input.Invoice.ID)).
Where(billinginvoice.Namespace(input.Invoice.Namespace))
Where(billinginvoice.Namespace(input.Invoice.Namespace)).
Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering))

if input.Expand.Has(billing.GatheringInvoiceExpandLines) {
query = a.expandGatheringInvoiceLines(query, input.Expand)
Expand Down
142 changes: 93 additions & 49 deletions openmeter/billing/adapter/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package billingadapter

import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"

Expand All @@ -20,7 +20,6 @@ import (
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoicevalidationissue"
"github.com/openmeterio/openmeter/openmeter/ent/db/predicate"
"github.com/openmeterio/openmeter/openmeter/streaming"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/models"
Expand All @@ -30,7 +29,7 @@ import (

var _ billing.InvoiceAdapter = (*adapter)(nil)

func (a *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdInput) (billing.StandardInvoice, error) {
func (a *adapter) GetStandardInvoiceById(ctx context.Context, in billing.GetInvoiceByIdInput) (billing.StandardInvoice, error) {
if err := in.Validate(); err != nil {
return billing.StandardInvoice{}, billing.ValidationError{
Err: err,
Expand Down Expand Up @@ -82,36 +81,6 @@ func (a *adapter) expandInvoiceLineItems(query *db.BillingInvoiceQuery, expand b
})
}

func (a *adapter) DeleteGatheringInvoices(ctx context.Context, input billing.DeleteGatheringInvoicesInput) error {
if err := input.Validate(); err != nil {
return billing.ValidationError{
Err: err,
}
}

return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error {
nAffected, err := tx.db.BillingInvoice.Update().
Where(billinginvoice.IDIn(input.InvoiceIDs...)).
Where(billinginvoice.Namespace(input.Namespace)).
Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)).
ClearPeriodStart().
ClearPeriodEnd().
SetDeletedAt(clock.Now()).
Save(ctx)
if err != nil {
return err
}

if nAffected != len(input.InvoiceIDs) {
return billing.ValidationError{
Err: errors.New("invoices failed to delete"),
}
}

return nil
})
}

func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesInput) (billing.ListInvoicesResponse, error) {
if err := input.Validate(); err != nil {
return billing.ListInvoicesResponse{}, billing.ValidationError{
Expand Down Expand Up @@ -161,10 +130,6 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
query = query.Where(billinginvoice.CreatedAtLTE(*input.CreatedBefore))
}

if len(input.ExtendedStatuses) > 0 {
query = query.Where(billinginvoice.StatusIn(input.ExtendedStatuses...))
}

if len(input.IDs) > 0 {
query = query.Where(billinginvoice.IDIn(input.IDs...))
}
Expand All @@ -173,14 +138,46 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
query = query.Where(billinginvoice.DeletedAtIsNil())
}

if len(input.Statuses) > 0 {
query = query.Where(func(s *sql.Selector) {
s.Where(sql.Or(
lo.Map(input.Statuses, func(status string, _ int) *sql.Predicate {
return sql.Like(billinginvoice.FieldStatus, status+"%")
})...,
))
})
if len(input.InvoiceTypes) > 0 {
includeStandard := slices.Contains(input.InvoiceTypes, billing.InvoiceTypeStandard) || len(input.InvoiceTypes) == 0
includeGathering := slices.Contains(input.InvoiceTypes, billing.InvoiceTypeGathering) || len(input.InvoiceTypes) == 0

// Hack: right now we are using the statuses to filter for standard and gathering invoices
queries := []predicate.BillingInvoice{}
if includeStandard {
queries = append(queries, func(s *sql.Selector) {
predicates := []*sql.Predicate{
sql.Not(sql.EQ(billinginvoice.FieldStatus, billing.StandardInvoiceStatusGathering)),
}

if len(input.StandardInvoiceStatuses) > 0 {
predicates = append(predicates, sql.Or(
lo.Map(input.StandardInvoiceStatuses, func(status string, _ int) *sql.Predicate {
return sql.Like(billinginvoice.FieldStatus, status+"%")
})...,
))
}

if len(input.StandardInvoiceExtendedStatuses) > 0 {
predicates = append(predicates, sql.In(billinginvoice.FieldStatus, lo.Map(input.StandardInvoiceExtendedStatuses, func(status billing.StandardInvoiceStatus, _ int) any {
return status
})...))
}

s.Where(sql.And(predicates...))
})
}

if includeGathering {
queries = append(queries, func(s *sql.Selector) {
s.Where(
sql.EQ(billinginvoice.FieldStatus, billing.StandardInvoiceStatusGathering))
})
}

query = query.Where(billinginvoice.Or(
queries...,
))
}

if len(input.Currencies) > 0 {
Expand Down Expand Up @@ -248,7 +245,7 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
query = query.Order(billinginvoice.ByCreatedAt(order...))
}

response := pagination.Result[billing.StandardInvoice]{
response := pagination.Result[billing.Invoice]{
Page: input.Page,
}

Expand All @@ -257,14 +254,33 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
return response, err
}

result := make([]billing.StandardInvoice, 0, len(paged.Items))
result := make([]billing.Invoice, 0, len(paged.Items))
for _, invoice := range paged.Items {
mapped, err := tx.mapStandardInvoiceFromDB(ctx, invoice, input.Expand)
if invoice.Status == billing.StandardInvoiceStatusGathering {
expand := billing.GatheringInvoiceExpands{}
if input.Expand.Lines {
expand = expand.With(billing.GatheringInvoiceExpandLines)
}

if input.Expand.DeletedLines {
expand = expand.With(billing.GatheringInvoiceExpandDeletedLines)
}

gatheredMapped, err := tx.mapGatheringInvoiceFromDB(invoice, expand)
if err != nil {
return response, err
}

result = append(result, billing.NewInvoice(gatheredMapped))
continue
}

stdMapped, err := tx.mapStandardInvoiceFromDB(ctx, invoice, input.Expand)
if err != nil {
return response, err
}

result = append(result, mapped)
result = append(result, billing.NewInvoice(stdMapped))
}

response.TotalCount = paged.TotalCount
Expand Down Expand Up @@ -866,3 +882,31 @@ func (a *adapter) IsAppUsed(ctx context.Context, appID app.AppID) error {

return nil
}

func (a *adapter) GetInvoiceType(ctx context.Context, input billing.GetInvoiceTypeAdapterInput) (billing.InvoiceType, error) {
if err := input.Validate(); err != nil {
return "", err
}

return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.InvoiceType, error) {
invoice, err := tx.db.BillingInvoice.Query().
Where(billinginvoice.ID(input.ID)).
Where(billinginvoice.Namespace(input.Namespace)).
Only(ctx)
if err != nil {
if db.IsNotFound(err) {
return "", billing.NotFoundError{
Err: fmt.Errorf("invoice not found: %w", err),
}
}

return "", err
}

if invoice.Status == billing.StandardInvoiceStatusGathering {
return billing.InvoiceTypeGathering, nil
}

return billing.InvoiceTypeStandard, nil
})
}
Loading
Loading