depressed-bench

joined 11 months ago
[–] depressed-bench@alien.top 1 points 9 months ago

Can you show ordering equivariance of the single-matrix formulation versus the two-matrix one?

This form of attention must be equivariant with respect to token order, e.g.

attn(ABCD) == rot2(attn(rot2(ABCD))) == rot2(attn(CDAB)) 

I am using rot2 here for a rotation of the token sequence by two positions.
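
To make this concrete, here is a minimal JAX sketch (my own illustration, not from the original exchange) checking rotation equivariance for single-head attention without positional encodings:

```python
import jax
import jax.numpy as jnp

def attn(x, wq, wk, wv):
    # Single-head self-attention, no positional encoding.
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

key = jax.random.PRNGKey(0)
kx, kq, kk, kv = jax.random.split(key, 4)
x = jax.random.normal(kx, (4, 8))            # 4 tokens (A, B, C, D), dim 8
wq, wk, wv = (jax.random.normal(k, (8, 8)) for k in (kq, kk, kv))

rot2 = lambda t: jnp.roll(t, 2, axis=0)      # rotate token order by two

# Equivariance: attn(rot2(x)) == rot2(attn(x)); since rot2 is self-inverse
# on four tokens, attn(x) == rot2(attn(rot2(x))) as in the equation above.
print(jnp.allclose(attn(x, wq, wk, wv),
                   rot2(attn(rot2(x), wq, wk, wv))))  # True
```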

[–] depressed-bench@alien.top 1 points 10 months ago

I will give a different answer: systems that do online learning are certainly not deterministic in the common sense of the word, as their internals change based on non-deterministic behaviour.

Systems that rely on noise generated by non-deterministic processes are also non-deterministic.

This non-determinism is rooted in changes to parts of the state or the input; for identical state and inputs, the systems are deterministic as long as no bit flips or quantum effects occur in the silicon.
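
For what it's worth, JAX makes the "identical state and inputs implies identical outputs" point easy to demonstrate, because PRNG state is passed around explicitly; a small sketch of mine:

```python
import jax
import jax.numpy as jnp

def noisy_step(params, x, key):
    # A toy "system" whose output depends on explicit state (params, key)
    # and on the input x; the noise is just another input here.
    noise = jax.random.normal(key, x.shape)
    return params * x + noise

params, x = 2.0, jnp.arange(3.0)

# Identical state and inputs -> identical outputs (deterministic replay).
k1 = jax.random.PRNGKey(42)
print(jnp.array_equal(noisy_step(params, x, k1), noisy_step(params, x, k1)))  # True

# Different RNG state -> different outputs, i.e. the apparent non-determinism.
k2 = jax.random.PRNGKey(43)
print(jnp.array_equal(noisy_step(params, x, k1), noisy_step(params, x, k2)))  # False
```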

[–] depressed-bench@alien.top 1 points 10 months ago

That checks out tbh. I have seen stuff >.>

[–] depressed-bench@alien.top 1 points 10 months ago (2 children)

Hey, that's what I am seeing in r/experienceddevs :)

[–] depressed-bench@alien.top 1 points 10 months ago

I have not gone into the internals, but from the documentation they seem to behave in a very similar manner, in the sense that they both do symbolic execution: they send a tracer object through the function, record what happens to it, and then emit compiled code.

JAX specifically specializes on input shapes, and it seems that torch does as well. Both may do more jitting when necessary, e.g. re-tracing for different input shapes.
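
A quick way to see that shape specialization in JAX (a standard demonstration, not something from the thread): Python side effects only run while tracing, so a print statement reveals each fresh trace/compile.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # This print only executes during tracing, i.e. once per new
    # input shape/dtype that triggers a fresh compilation.
    print("tracing with shape", x.shape)
    return jnp.sum(x * x)

f(jnp.ones(3))   # prints: tracing with shape (3,)
f(jnp.ones(3))   # no print: the cached compilation is reused
f(jnp.ones(5))   # prints: tracing with shape (5,) -- re-traced for the new shape
```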

I have a hunch that JAX should be faster because it relies on XLA primitives, e.g. XLA-backed for-loops and scans, and I am not sure how PT handles that, given it most likely tries to remain backwards compatible and relies on Python semantics rather than compiler primitives.
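
For reference, these are the XLA-backed primitives I mean; a minimal sketch of a cumulative sum written with lax.scan, which lowers to a single XLA loop construct instead of unrolling a Python for-loop into the trace:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumsum_scan(xs):
    def step(carry, x):
        # carry accumulates the running sum; the second return value
        # is collected into the per-step output array.
        carry = carry + x
        return carry, carry
    _, ys = lax.scan(step, 0.0, xs)
    return ys

print(cumsum_scan(jnp.arange(5.0)))  # [ 0.  1.  3.  6. 10.]
```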