@@ -409,16 +409,25 @@ defmodule Bumblebee.Layers do
409409 ] )
410410
411411 kernel_shape = fn input_shape ->
412- kernel_shape = Axon.Shape . dense_kernel ( input_shape , units )
412+ unless Nx . rank ( input_shape ) >= 2 do
413+ raise ArgumentError ,
414+ "input shape must have at least rank 2, got rank" <>
415+ " #{ Nx . rank ( input_shape ) } "
416+ end
413417
414- # We expect a transposed kernel
415- kernel_shape
416- |> Tuple . to_list ( )
417- |> Enum . reverse ( )
418- |> List . to_tuple ( )
418+ # Transposed kernel compared to Axon.dense/3.
419+ { units , elem ( input_shape , Nx . rank ( input_shape ) - 1 ) }
419420 end
420421
421- bias_shape = & Axon.Shape . dense_bias ( & 1 , units )
422+ bias_shape = fn input_shape ->
423+ unless Nx . rank ( input_shape ) >= 2 do
424+ raise ArgumentError ,
425+ "input shape must have at least rank 2, got rank" <>
426+ " #{ Nx . rank ( input_shape ) } "
427+ end
428+
429+ { units }
430+ end
422431
423432 kernel = Axon . param ( "kernel" , kernel_shape , initializer: opts [ :kernel_initializer ] )
424433
@@ -1176,10 +1185,13 @@ defmodule Bumblebee.Layers do
11761185 "expected :upcast to be either :all or :normalization, got: #{ other } "
11771186 end
11781187
1179- weight =
1180- Axon . param ( "weight" , & Axon.Shape . norm_param ( & 1 , opts [ :channel_index ] ) ,
1181- initializer: opts [ :initializer ]
1182- )
1188+ weight_shape = fn input_shape ->
1189+ names = List . duplicate ( nil , Nx . rank ( input_shape ) )
1190+ axis = Nx.Shape . normalize_axis ( input_shape , opts [ :channel_index ] , names )
1191+ { elem ( input_shape , axis ) }
1192+ end
1193+
1194+ weight = Axon . param ( "weight" , weight_shape , initializer: opts [ :initializer ] )
11831195
11841196 Axon . layer ( impl , [ input , weight ] ,
11851197 name: opts [ :name ] ,
0 commit comments