Skip to content

Commit 0b397f6

Browse files
committed
Do not rely on private Axon.Shape
1 parent 8c94469 commit 0b397f6

File tree

1 file changed

+23
-11
lines changed

lib/bumblebee/layers.ex

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -409,16 +409,25 @@ defmodule Bumblebee.Layers do
409409
])
410410

411411
kernel_shape = fn input_shape ->
412-
kernel_shape = Axon.Shape.dense_kernel(input_shape, units)
412+
unless Nx.rank(input_shape) >= 2 do
413+
raise ArgumentError,
414+
"input shape must have at least rank 2, got rank" <>
415+
" #{Nx.rank(input_shape)}"
416+
end
413417

414-
# We expect a transposed kernel
415-
kernel_shape
416-
|> Tuple.to_list()
417-
|> Enum.reverse()
418-
|> List.to_tuple()
418+
# Transposed kernel compared to Axon.dense/3.
419+
{units, elem(input_shape, Nx.rank(input_shape) - 1)}
419420
end
420421

421-
bias_shape = &Axon.Shape.dense_bias(&1, units)
422+
bias_shape = fn input_shape ->
423+
unless Nx.rank(input_shape) >= 2 do
424+
raise ArgumentError,
425+
"input shape must have at least rank 2, got rank" <>
426+
" #{Nx.rank(input_shape)}"
427+
end
428+
429+
{units}
430+
end
422431

423432
kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
424433

@@ -1176,10 +1185,13 @@ defmodule Bumblebee.Layers do
11761185
"expected :upcast to be either :all or :normalization, got: #{other}"
11771186
end
11781187

1179-
weight =
1180-
Axon.param("weight", &Axon.Shape.norm_param(&1, opts[:channel_index]),
1181-
initializer: opts[:initializer]
1182-
)
1188+
weight_shape = fn input_shape ->
1189+
names = List.duplicate(nil, Nx.rank(input_shape))
1190+
axis = Nx.Shape.normalize_axis(input_shape, opts[:channel_index], names)
1191+
{elem(input_shape, axis)}
1192+
end
1193+
1194+
weight = Axon.param("weight", weight_shape, initializer: opts[:initializer])
11831195

11841196
Axon.layer(impl, [input, weight],
11851197
name: opts[:name],

0 commit comments

Comments (0)